diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index ffdbb46bd7c06..caddc8fbb7381 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -64,10 +64,81 @@ using framework::ir::Node; using GraphNodeVec = std::vector; using GraphNodeSet = std::unordered_set; +// Deal with subgraph's feed input var node: +// create a new input var node and it's feed op node +void AddFeedOpAndVar(const std::unordered_set& feed_vars, + const GraphNodeSet& cluster, + const std::unordered_map& old_op2new_op, + Graph* graph) { + for (auto* old_var : feed_vars) { + // create feed op + OpDesc desc; + desc.SetType("feed"); + desc.SetOutput("Out", {old_var->Name()}); + auto op = graph->CreateOpNode(&desc); + + // create new feed var node (SSAGraph) + auto var = graph->CreateVarNode(old_var->Var()); + + // link feed op and feed var + op->outputs = {var}; + var->inputs = {op}; + + // link feed var to cluster op + for (auto* old_op : old_var->outputs) { + if (cluster.count(old_op)) { + var->outputs.emplace_back(old_op2new_op.at(old_op)); + old_op2new_op.at(old_op)->inputs.emplace_back(var); + } + // Do not need relink old op or old var here, they will be + // fixed in RemoveLinkFromCluster, here we just deal with + // new subgraph's node. + } + } +} + +// Deal with subgraph's parameter var node: +// create a new input var node, it's data will get by scope, +// so it don't need feed op +void AddParamVar(const std::unordered_set& param_vars, + const GraphNodeSet& cluster, + const std::unordered_map& old_op2new_op, + Graph* graph) { + for (auto* old_var : param_vars) { + auto var = graph->CreateVarNode(old_var->Var()); + + for (auto* old_op : old_var->outputs) { + if (cluster.count(old_op)) { + var->outputs.emplace_back(old_op2new_op.at(old_op)); + old_op2new_op.at(old_op)->inputs.emplace_back(var); + } + } + } +} + +// Deal with subgraph's outputs var node: +// create a new output var node and it's fetch op +void AddOutputVar(const std::unordered_set& output_vars, + const GraphNodeSet& cluster, + const std::unordered_map& old_op2new_op, + Graph* graph) { + for (auto* old_var : output_vars) { + auto var = graph->CreateVarNode(old_var->Var()); + + for (auto* old_op : old_var->inputs) { + if (cluster.count(old_op)) { + var->inputs.emplace_back(old_op2new_op.at(old_op)); + old_op2new_op.at(old_op)->outputs.emplace_back(var); + } + } + } +} + // Create new subgraph with and op nodes are cluster nodes, and all // var node are from internal nodes -std::unique_ptr CreateNewSubGraph( - const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals) { +std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_internals, + const GraphNodeSet& cluster_inputs) { // Graph's constructor must has one parameter, and in our code, // the ProgramDesc is useless, so here we pass a temporary object. auto sub_graph = std::make_unique(framework::ProgramDesc()); @@ -84,6 +155,8 @@ std::unique_ptr CreateNewSubGraph( old_var2new_var[var] = sub_node; } + std::unordered_set need_feed_vars; + std::unordered_set param_vars, output_vars; // the subgraph is independently, so here we only need link // to the node in new subgraph, and discard the link to // out-graph. @@ -91,15 +164,36 @@ std::unique_ptr CreateNewSubGraph( for (auto* var : op->inputs) { if (cluster_internals.count(var)) { old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); + } else if (cluster_inputs.count(var)) { + if (var->Var()->IsParameter()) { + // Parameters have been preserved in scope, compared to feed var, + // param just need add new var and don't need add feed op. + // The var is used for check whether we need preserve the tensor + // when transform paddle scope to CINN scope. + param_vars.insert(var); + } else { + // When the var is subgraph input and the var is not parameter, + // we need add a new feed op to feed the var. + need_feed_vars.insert(var); + } } } for (auto* var : op->outputs) { if (cluster_internals.count(var)) { old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]); + } else { + // Create new output var node to guarantee the independency of + // subgraph. In other words, the subgraph has no connection with + // other graph, even the input graph. + output_vars.insert(var); } } } + AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, sub_graph.get()); + AddParamVar(param_vars, cluster, old_op2new_op, sub_graph.get()); + AddOutputVar(output_vars, cluster, old_op2new_op, sub_graph.get()); + for (auto* var : cluster_internals) { for (auto* op : var->inputs) { if (cluster.count(op)) { @@ -118,10 +212,12 @@ std::unique_ptr CreateNewSubGraph( // This interface is used to classify all variables involved in a cluster into // three types: inputs, outputs, and internals. -// Specially, the internal node is a node that only used by sub-graph, and +// The input node is some subgraph op's input but not any subgraph op's output. +// The output node is some subgraph op's output and some out-graph op's input. +// Specially, the internal node is a node that only used by subgraph, and // out-graph should not using this node at all. -// inputs & outputs & internals == NULL -// inputs | outputs | internals == all graph node +// cluster_inputs & cluster_outputs & cluster_internals == NULL +// cluster_outputs | cluster_internals == all graph op's outputs node void AnalyseClusterVariables(const GraphNodeSet& cluster, GraphNodeSet* cluster_inputs, GraphNodeSet* cluster_outputs, @@ -154,10 +250,6 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster, } } - // if a output node also exists in input list, remove. - for (auto* var_node : *cluster_inputs) { - cluster_outputs->erase(var_node); - } // if a output node also exists in internal list, remove. for (auto* var_node : *cluster_internals) { cluster_outputs->erase(var_node); @@ -206,14 +298,23 @@ void RemoveLinkFromCluster(const GraphNodeSet& cluster, // removing useless link from cluster_inputs to cluster for (auto* var_node : cluster_inputs) { - auto preserved_nodes = get_preserved_ops(var_node->outputs); - var_node->outputs.assign(preserved_nodes.begin(), preserved_nodes.end()); + auto preserved_ops = get_preserved_ops(var_node->outputs); + var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end()); + // According to SSA form, a var node must not be any two op's output, + // and the cluster_inputs var nodes is defined as an out-graph op's + // output, so the cluster_inputs var nodes are not any subgraph op's + // output. Do not reassign input list here. } // removing useless link from cluster to cluster_outputs for (auto* var_node : cluster_outputs) { - auto preserved_nodes = get_preserved_ops(var_node->inputs); - var_node->inputs.assign(preserved_nodes.begin(), preserved_nodes.end()); + auto preserved_ops = get_preserved_ops(var_node->inputs); + var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end()); + + // Note that cluster_outputs var node maybe some subgraph op's input, + // here we need remove them. + preserved_ops = get_preserved_ops(var_node->outputs); + var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end()); } } @@ -272,7 +373,7 @@ void SearchAllSubgraphs(Graph* graph, &cluster_internals); cinn_subgraphs->emplace_back( - CreateNewSubGraph(cluster_set, cluster_internals)); + CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs)); // replacing subgraph to a new special op node ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index 883d5c6fbfb39..bf68a2b554b7f 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -54,6 +54,35 @@ inline Node* GetNode(const std::unordered_set& nodes, [&op_name](const Node* node) { return node->Name() == op_name; }); } +inline bool CheckGraphIndependence(const std::unordered_set& nodes) { + auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool { + if (n1->IsOp() && !n2->IsVar()) { + return false; + } + if (n1->IsVar() && !n2->IsOp()) { + return false; + } + if (nodes.count(n2) == 0) { + return false; + } + return true; + }; + + for (auto node : nodes) { + for (auto in : node->inputs) { + if (!check_node_ok(node, in)) { + return false; + } + } + for (auto out : node->outputs) { + if (!check_node_ok(node, out)) { + return false; + } + } + } + return true; +} + std::unique_ptr BuildNoCinnSubgraph() { ProgramDesc prog; auto g = std::make_unique(prog); @@ -67,6 +96,8 @@ std::unique_ptr BuildNoCinnSubgraph() { VarDesc var1("var1"); VarDesc var2("var2"); + var2.SetPersistable(true); + var2.SetIsParameter(true); VarDesc var3("var3"); VarDesc var4("var4"); @@ -109,6 +140,7 @@ TEST(BuildCinnPassTest, NoCinnSubgraph) { // After search, origin graph should no change ASSERT_EQ(previous_nodes, g->Nodes()); + ASSERT_TRUE(CheckGraphIndependence(g->Nodes())); // After search, there should one cinn subgraph ASSERT_TRUE(cinn_subgraphs.empty()); @@ -119,11 +151,8 @@ std::unique_ptr BuildAllOpSupportCinnGraph() { auto g = std::make_unique(prog); // v1 -- - // | // | --> mul --> v3 -- - // | | // v2 -- | --> add --> v5 --> relu --> v6 - // | // v4 -- OpDesc add_op; @@ -135,6 +164,8 @@ std::unique_ptr BuildAllOpSupportCinnGraph() { VarDesc var1("var1"); VarDesc var2("var2"); + var2.SetPersistable(true); + var2.SetIsParameter(true); VarDesc var3("var3"); VarDesc var4("var4"); VarDesc var5("var5"); @@ -192,6 +223,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { // v4 --| const auto& nodes = g->Nodes(); ASSERT_EQ(nodes.size(), static_cast(5)); + ASSERT_TRUE(CheckGraphIndependence(nodes)); // A new op named kCinnLaunchOp should be added ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); @@ -214,16 +246,34 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ASSERT_FALSE(CheckNodeExisted(nodes, "relu")); // After search, there should has just one cinn subgraph - // mul --> v3 --> add --> v5 --> relu + // feed --> v1 -- + // | --> mul --> v3 -- + // v2 -- | --> add --> v5 --> relu --> v6 + // feed --> v4 -- ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); const auto& subgraph = cinn_subgraphs.back(); const auto& subnodes = subgraph->Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(5)); + ASSERT_EQ(subnodes.size(), static_cast(11)); + ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); + ASSERT_EQ(CountNode(subnodes, "feed"), 2); + + // No-parameter input should has feed op + auto new_v1 = GetNode(subnodes, "var1"); + ASSERT_EQ(new_v1->inputs.size(), static_cast(1)); + ASSERT_EQ(new_v1->outputs.size(), static_cast(1)); + ASSERT_EQ(new_v1->inputs[0]->Name(), "feed"); + ASSERT_EQ(new_v1->outputs[0]->Name(), "mul"); + + // Parameter input should not has feed op + auto new_v2 = GetNode(subnodes, "var2"); + ASSERT_TRUE(new_v2->inputs.empty()); + ASSERT_EQ(new_v2->outputs.size(), static_cast(1)); + ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); } std::unique_ptr BuildGraphWithOneCinnSubgraph() { @@ -231,9 +281,7 @@ std::unique_ptr BuildGraphWithOneCinnSubgraph() { auto g = std::make_unique(prog); // fake1 --> v1 -- - // | // | --> mul --> v3 --> relu --> v4 --> fake2 - // | // v2 -- OpDesc fake1_op; @@ -247,6 +295,8 @@ std::unique_ptr BuildGraphWithOneCinnSubgraph() { VarDesc var1("var1"); VarDesc var2("var2"); + var2.SetPersistable(true); + var2.SetIsParameter(true); VarDesc var3("var3"); VarDesc var4("var4"); @@ -299,6 +349,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { // v2 -- const auto& nodes = g->Nodes(); ASSERT_EQ(nodes.size(), static_cast(6)); + ASSERT_TRUE(CheckGraphIndependence(nodes)); // A new op named kCinnLaunchOp should be added ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); @@ -312,15 +363,19 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { ASSERT_TRUE(CheckNodeExisted(nodes, "fake2")); // After search, there should has just one cinn subgraph - // mul --> v3 --> relu + // feed --> v1 -- + // | --> mul --> v3 --> relu --> v4 + // v2 -- ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); const auto& subgraph = cinn_subgraphs.back(); const auto& subnodes = subgraph->Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(3)); + ASSERT_EQ(subnodes.size(), static_cast(7)); + ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); + ASSERT_EQ(CountNode(subnodes, "feed"), 1); } std::unique_ptr BuildGraphWithMultiCinnSubgraph() { @@ -328,9 +383,7 @@ std::unique_ptr BuildGraphWithMultiCinnSubgraph() { auto g = std::make_unique(prog); // fake1 --> v1 -- - // | // | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3 - // | // v2 -- OpDesc fake1_op; @@ -346,6 +399,8 @@ std::unique_ptr BuildGraphWithMultiCinnSubgraph() { VarDesc var1("var1"); VarDesc var2("var2"); + var2.SetPersistable(true); + var2.SetIsParameter(true); VarDesc var3("var3"); VarDesc var4("var4"); VarDesc var5("var5"); @@ -406,6 +461,7 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { // v2 - const auto& nodes = g->Nodes(); ASSERT_EQ(nodes.size(), static_cast(10)); + ASSERT_TRUE(CheckGraphIndependence(nodes)); // A new op named kCinnLaunchOp should be added ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); @@ -424,15 +480,27 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { // and each of subgraphs just has one node. ASSERT_EQ(cinn_subgraphs.size(), static_cast(2)); - // subgraph1: relu + // subgraph1: + // feed --> v4 --> relu --> v5 + // subgraph2: + // feed --> v1 -- + // | --> mul --> v3 + // v2 -- const auto& subgraph1 = cinn_subgraphs[0]; const auto& subnodes1 = subgraph1->Nodes(); - ASSERT_EQ(subnodes1.size(), static_cast(1)); + ASSERT_TRUE(CheckGraphIndependence(subnodes1)); - // subgraph2: mul const auto& subgraph2 = cinn_subgraphs[1]; const auto& subnodes2 = subgraph2->Nodes(); - ASSERT_EQ(subnodes2.size(), static_cast(1)); + ASSERT_TRUE(CheckGraphIndependence(subnodes2)); + + if (CheckNodeExisted(subnodes1, "relu")) { + ASSERT_EQ(subnodes1.size(), static_cast(4)); + ASSERT_EQ(subnodes2.size(), static_cast(5)); + } else { + ASSERT_EQ(subnodes2.size(), static_cast(4)); + ASSERT_EQ(subnodes1.size(), static_cast(5)); + } } } // namespace paddle2cinn