diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 63559e201594a..e4c9dc72128f4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() { return op; } +PDNode *patterns::DuplicatedOutputs::operator()() { + auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"}); + op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "bfloat16"; + }); + return op; +} + PDNode *patterns::MKLDNNInPlace::operator()() { const std::unordered_set &supported_op_types = { "abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"}; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 79f1d63a15190..d6400ed6945bf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase { PATTERN_DECL_NODE(op); }; +struct DuplicatedOutputs : public PatternBase { + DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "many_outputs_op") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(op); +}; + // Pattern used for enforcing inplace computation for in-place computation // supporting DNNL ops. softmax, batch_norm and layer_norm struct MKLDNNInPlace : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc index 5f9aefc1e7a0b..f1bd34a5ad4f6 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc @@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) { } void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, - int* quantize_counter) { + int& quantize_counter) { std::vector input_names; // Find the name of the input linking op to op_in @@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, IR_NODE_LINK_TO(op_in, quantize_op); IR_NODE_LINK_TO(quantize_op, quantize_out_node); IR_NODE_LINK_TO(quantize_out_node, op); - (*quantize_counter)++; + quantize_counter++; } -void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { +void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) { auto inputs = op->inputs; PADDLE_ENFORCE_GE(inputs.size(), 1, platform::errors::InvalidArgument( @@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { IR_NODE_LINK_TO(inputs[i], quantize_op); IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]); IR_NODE_LINK_TO(quantize_out_nodes[i], op); - (*quantize_counter)++; + quantize_counter++; } op->Op()->SetInput("X", quantize_out_node_names); @@ -136,7 +136,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { // Operators like Concat and Sum have a single input name X, which actually // consists of multiple inputs. Such operators require a different way to find // pattern and add quantize ops. -void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { +void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) { GraphPatternDetector gpd; patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(), "duplicated_inputs"}; @@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { // Adding quantize ops before all operators except Concat and Sum, which have // already been handled in AddReoderBeforeDuplicatedInputs -void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { +void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) { GraphPatternDetector gpd; patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), "first_bfloat16_ops"}; @@ -169,60 +169,134 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const { int quantize_counter = 0; - AddReoderBeforeDuplicatedInputs(graph, &quantize_counter); - AddReoderBeforeSingleInputs(graph, &quantize_counter); + AddReoderBeforeDuplicatedInputs(graph, quantize_counter); + AddReoderBeforeSingleInputs(graph, quantize_counter); PrettyLogDetail("--- added %d quantize ops before bfloat16 op", quantize_counter); } -void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { +void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out, + int& dequantize_counter) { + if (op->Op()->Type() == "prior_box") return; + + // Find the name of the output linking op to op_out + std::vector output_names; + for (auto name : op->Op()->OutputNames()) + for (auto output_name : op->Op()->Output(name)) + if (output_name == op_out->Name() && IsPermittedOutputName(name)) + output_names.push_back(name); + + if (output_names.empty()) return; + + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); + + OpDesc deq_desc; + deq_desc.SetType("dequantize"); + deq_desc.SetInput("Input", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("Output", std::vector({op_out->Name()})); + deq_desc.SetAttr("Scale", 1.0f); + deq_desc.SetAttr("Shift", 0.0f); + auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. + + for (auto name = output_names.begin(); name < output_names.end(); name++) + op->Op()->SetOutput(*name, + std::vector({dequantize_in_node->Name()})); + + UnlinkNodes(op, op_out); + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, op_out); + + dequantize_counter++; +} + +void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) { + auto outputs = op->outputs; + PADDLE_ENFORCE_GE(outputs.size(), 1, + platform::errors::InvalidArgument( + "OP(%s)'s outputs(%d) must be equal or greater than 1.", + op->Name(), outputs.size())); + PADDLE_ENFORCE_EQ(op->inputs.size(), 1, + platform::errors::InvalidArgument( + "OP(%s)'s inputs(%d) must be equal to 1.", op->Name(), + op->inputs.size())); + + OpDesc deq_desc; + deq_desc.SetType("dequantize"); + + std::vector dequantize_in_nodes(outputs.size()); + std::vector dequantize_in_node_names(outputs.size()); + + for (size_t i = 0; i < outputs.size(); i++) { + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc); + dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name(); + + deq_desc.SetInput("Input", + std::vector({dequantize_in_node_names[i]})); + deq_desc.SetOutput("Output", + std::vector({outputs[i]->Name()})); + + deq_desc.SetAttr("Scale", 1.f); + deq_desc.SetAttr("Shift", 0.0f); + deq_desc.SetAttr("bfloat16", true); + deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") + ? op->Op()->GetAttr("data_layout") + : std::string("NCHW")); + auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. + + UnlinkNodes(op, outputs[i]); + IR_NODE_LINK_TO(op, dequantize_in_nodes[i]); + IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op); + IR_NODE_LINK_TO(dequantize_op, outputs[i]); + + dequantize_counter++; + } + + op->Op()->SetOutput("Out", dequantize_in_node_names); +} + +// Operators like split have a single output name Out, which actually +// consists of multiple outputs. Such operators require a different way to find +// pattern and add dequantize ops. +void AddReoderAfterDuplicatedOutputs(ir::Graph* graph, + int& dequantize_counter) { + GraphPatternDetector gpd; + patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(), + "duplicated_outputs"}; + duplicated_outputs(); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs); + AddDequantizes(g, op, dequantize_counter); + }; + gpd(graph, handler); +} + +// Adding dequantize ops after all operators except split, which has +// already been handled in AddReoderAfterDuplicatedOutputs +void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) { GraphPatternDetector gpd; patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), "last_bfloat16_ops"}; bfloat16_ops(); - int dequantize_counter = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops); - - if (op->Op()->Type() != "prior_box") { - // Find the name of the output linking op to op_out - std::vector output_names; - for (auto name : op->Op()->OutputNames()) - for (auto output_name : op->Op()->Output(name)) - if (output_name == op_out->Name() && IsPermittedOutputName(name)) - output_names.push_back(name); - - if (output_names.empty()) return; - - VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); - auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); - - OpDesc deq_desc; - deq_desc.SetType("dequantize"); - deq_desc.SetInput("Input", - std::vector({dequantize_in_node->Name()})); - deq_desc.SetOutput("Output", std::vector({op_out->Name()})); - deq_desc.SetAttr("Scale", 1.0f); - deq_desc.SetAttr("Shift", 0.0f); - auto dequantize_op = - g->CreateOpNode(&deq_desc); // OpDesc will be copied. - - for (auto name = output_names.begin(); name < output_names.end(); name++) - op->Op()->SetOutput( - *name, std::vector({dequantize_in_node->Name()})); - - UnlinkNodes(op, op_out); - IR_NODE_LINK_TO(op, dequantize_in_node); - IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); - IR_NODE_LINK_TO(dequantize_op, op_out); - - dequantize_counter++; + GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); + if (op->Op()->Type() != "split") { + AddDequantize(g, op, op_out, dequantize_counter); } }; gpd(graph, handler); +} + +void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { + int dequantize_counter = 0; + AddReoderAfterDuplicatedOutputs(graph, dequantize_counter); + AddReoderAfterSingleOutputs(graph, dequantize_counter); PrettyLogDetail("--- added %d dequantize ops after bfloat16 op", dequantize_counter); } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc index f620b4c94fe89..877ee71fc2d85 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -45,7 +45,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("Input", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); op->SetAttr("mkldnn_data_type", mkldnn_data_type); - } else if (type == "concat" || type == "sum") { + } else if (type == "concat" || type == "sum" || type == "split") { op->SetInput("X", inputs); op->SetOutput("Out", outputs); op->SetAttr("mkldnn_data_type", mkldnn_data_type); @@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) { bool use_mkldnn = true; int quant_op = 3; int dequant_op = 3; + // each added op consists of 2 nodes int added_nodes = quant_op * 2 + dequant_op * 2; MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes); } @@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) { bool use_mkldnn = true; int quant_op = 4; int dequant_op = 3; + // each added op consists of 2 nodes int added_nodes = quant_op * 2 + dequant_op * 2; MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op, added_nodes); @@ -164,11 +166,35 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) { bool use_mkldnn = true; int quant_op = 5; int dequant_op = 3; + // each added op consists of 2 nodes int added_nodes = quant_op * 2 + dequant_op * 2; MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op, added_nodes); } +ProgramDesc BuildProgramDescDuplicatedOutput(bool use_mkldnn) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32"); + SetOp(&prog, "split", "Split", {"b"}, {"c", "d"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "transpose2", "Transpose", {"c"}, {"e"}, use_mkldnn, "float32"); + SetOp(&prog, "reshape2", "Reshape", {"d"}, {"f"}, use_mkldnn, "bfloat16"); + + return prog; +} + +TEST(CpuBfloat16Pass, duplicated_output_ops) { + bool use_mkldnn = true; + int quant_op = 2; + int dequant_op = 3; + // each added op consists of 2 nodes + int added_nodes = quant_op * 2 + dequant_op * 2; + MainTest(BuildProgramDescDuplicatedOutput(use_mkldnn), quant_op, dequant_op, + added_nodes); +} + ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) { ProgramDesc prog; for (auto& v : variable_names) { @@ -190,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) { bool use_mkldnn = true; int quant_op = 3; int dequant_op = 3; + // each added op consists of 2 nodes int added_nodes = quant_op * 2 + dequant_op * 2; MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op, added_nodes); diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc index 7b2166eaf11f9..722681fb7bc3f 100644 --- a/paddle/phi/kernels/cpu/split_kernel.cc +++ b/paddle/phi/kernels/cpu/split_kernel.cc @@ -70,4 +70,5 @@ PD_REGISTER_KERNEL(split, int64_t, int, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {}