Fix several bugs for enabling Paddle to train with CINN. #36739

Merged · 8 commits · Oct 28, 2021
16 changes: 9 additions & 7 deletions paddle/fluid/framework/details/build_strategy.cc
@@ -52,6 +52,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ResolveOptionConfliction();

AppendPrintGraphPass("graph_viz_pass", "_original_graph");

#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph");
}
#endif

AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
@@ -74,13 +83,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");

#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
}
#endif

SetCollectiveContext();
}
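For context on the build_strategy.cc change above: the relocated build_cinn_pass is only appended when FLAGS_use_cinn is set. The snippet below is a minimal editor's sketch, not part of this PR, of flipping that flag from C++ before the pass builder runs; it assumes the flag is a plain gflags boolean, and the helper name EnableCinnForThisProcess is hypothetical.

#include "gflags/gflags.h"

DECLARE_bool(use_cinn);  // the flag tested by the new branch in build_strategy.cc

// Hypothetical helper: turn CINN on for this process so that build_cinn_pass
// (and its "_build_cinn_graph" graph_viz dump) are appended by the pass builder.
void EnableCinnForThisProcess() { FLAGS_use_cinn = true; }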

4 changes: 2 additions & 2 deletions paddle/fluid/framework/paddle2cinn/CMakeLists.txt
@@ -1,7 +1,7 @@
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
227 changes: 104 additions & 123 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -26,8 +26,10 @@ limitations under the License. */
#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

namespace paddle {
@@ -40,11 +42,28 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeSet = std::unordered_set<Node*>;

namespace {
int ExtractOpRole(const GraphNodeSet& cluster) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
for (auto* n : cluster) {
if (n->Op() && n->Op()->HasAttr(attr_name)) {
op_roles.insert(BOOST_GET_CONST(int, n->Op()->GetAttr(attr_name)));
}
}
if (op_roles.size() == 1U) {
return *(op_roles.begin());
} else {
return static_cast<int>(OpRole::kNotSpecified);
}
}

// Deal with subgraph's feed input var node:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : feed_vars) {
// create feed op
@@ -53,21 +72,19 @@ void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
desc.SetOutput("Out", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);

// create new feed var node (SSAGraph)
auto var = graph->CreateVarNode(old_var->Var());
// get new feed var node
auto* var = old_var2new_var.at(old_var);

// link feed op and feed var
op->outputs = {var};
var->inputs = {op};
IR_NODE_LINK_TO(op, var);

// link feed var to cluster op
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
// Do not need relink old op or old var here, they will be
// fixed in RemoveLinkFromCluster, here we just deal with
// fixed in RemoveSubGraphFromGraph, here we just deal with
// new subgraph's node.
}
}
@@ -79,14 +96,14 @@ void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
void AddParamVar(const std::unordered_set<Node*>& param_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto var = graph->CreateVarNode(old_var->Var());
auto* var = old_var2new_var.at(old_var);

for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
@@ -97,14 +114,14 @@ void AddParamVar(const std::unordered_set<Node*>& param_vars,
void AddOutputVar(const std::unordered_set<Node*>& output_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
const std::unordered_map<Node*, Node*>& old_var2new_var,
Graph* graph) {
for (auto* old_var : output_vars) {
auto var = graph->CreateVarNode(old_var->Var());
auto* var = old_var2new_var.at(old_var);

for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) {
var->inputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->outputs.emplace_back(var);
IR_NODE_LINK_TO(old_op2new_op.at(old_op), var);
}
}
}
@@ -136,6 +153,18 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
old_var2new_var[var] = sub_node;
}
for (auto* var : cluster_inputs) {
if (var->Var()) {
auto* sub_node = subgraph->CreateVarNode(var->Var());
Contributor

Why isn't CreateEmptyNode needed here when var->Var() == nullptr?

Contributor Author

Because the subgraph does not need nodes whose var_desc is null, and CINN compilation never uses them. Such empty vars exist only to add dependency relations, which are already reflected in the rewritten main graph. (An illustrative sketch of this rule follows at the end of this hunk.)

old_var2new_var[var] = sub_node;
}
}
for (auto* var : cluster_outputs) {
if (var->Var()) {
auto* sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
}

std::unordered_set<Node*> need_feed_vars;
std::unordered_set<Node *> param_vars, output_vars;
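Editor's note on the review thread above: the sketch below restates the rule under discussion. It is an illustration, not code from this PR; CloneBoundaryVars is a hypothetical helper, and the types mirror the aliases used in this file.

#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"

using Node = paddle::framework::ir::Node;
using Graph = paddle::framework::ir::Graph;

// Clone only the boundary vars that carry a VarDesc into the standalone
// subgraph. A var whose Var() is nullptr is a dependency-only placeholder in
// the main graph; CINN never compiles it, and the ordering it encodes is kept
// by the links added around the new cinn launch op, so it is skipped here.
void CloneBoundaryVars(const std::unordered_set<Node*>& boundary_vars,
                       Graph* subgraph,
                       std::unordered_map<Node*, Node*>* old_var2new_var) {
  for (auto* var : boundary_vars) {
    if (var->Var() == nullptr) continue;  // dependency-only node, not cloned
    (*old_var2new_var)[var] = subgraph->CreateVarNode(var->Var());
  }
}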
@@ -144,8 +173,10 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// out-graph.
for (auto* op : cluster) {
for (auto* var : op->inputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
// one output var maybe an input of the cluster
if (cluster_internals.count(var) ||
(cluster_outputs.count(var) && old_var2new_var.count(var))) {
IR_NODE_LINK_TO(old_var2new_var.at(var), old_op2new_op.at(op));
} else if (cluster_inputs.count(var) && var->Var() != nullptr) {
if (var->Var()->IsParameter()) {
// Parameters have been preserved in scope, compared to feed var,
@@ -162,7 +193,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
for (auto* var : op->outputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
IR_NODE_LINK_TO(old_op2new_op.at(op), old_var2new_var.at(var));
} else if (cluster_outputs.count(var) && var->Var() != nullptr) {
// Create new output var node to guarantee the independency of
// subgraph. In other words, the subgraph has no connection with
@@ -172,22 +203,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
}

AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, subgraph.get());
AddParamVar(param_vars, cluster, old_op2new_op, subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, subgraph.get());

for (auto* var : cluster_internals) {
for (auto* op : var->inputs) {
if (cluster.count(op)) {
old_var2new_var[var]->inputs.emplace_back(old_op2new_op[op]);
}
}
for (auto* op : var->outputs) {
if (cluster.count(op)) {
old_var2new_var[var]->outputs.emplace_back(old_op2new_op[op]);
}
}
}
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
AddParamVar(param_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());

return subgraph;
}
@@ -238,117 +259,77 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster,
}
}

Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, Graph* graph) {
// add special cinn op
framework::OpDesc special_op_desc;
special_op_desc.SetType(kCinnLaunchOp);
void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, Node* cinn_op_node) {
// add new link from cluster_inputs to cinn_op_node
for (auto* var_node : cluster_inputs) {
IR_NODE_LINK_TO(var_node, cinn_op_node);
}

// add new link from cinn_op_node to cluster_outputs
for (auto* var_node : cluster_outputs) {
IR_NODE_LINK_TO(cinn_op_node, var_node);
}
}

void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, Graph* graph) {
// Add the cinn launch op
framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names;
std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
[&input_names](Node* n) {
if (n->Var() != nullptr) {
input_names.emplace_back(n->Name());
}
});
special_op_desc.SetInput("X", input_names);
cinn_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names](Node* n) {
if (n->Var() != nullptr) {
output_names.emplace_back(n->Name());
}
});
special_op_desc.SetOutput("Out", output_names);
special_op_desc.SetAttr(kCompilationKey, compilation_key);
special_op_desc.Flush();
auto* special_op_node = graph->CreateOpNode(&special_op_desc);
special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end());
special_op_node->outputs.assign(cluster_outputs.begin(),
cluster_outputs.end());
return special_op_node;
}

void AddLinkToSpecialOp(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
Node* special_op_node) {
// add new link from cluster_inputs to special_op_node
for (auto* var_node : cluster_inputs) {
var_node->outputs.push_back(special_op_node);
}

// add new link from special_op_node to cluster_outputs
for (auto* var_node : cluster_outputs) {
var_node->inputs.push_back(special_op_node);
}
}

void RemoveLinkFromCluster(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
// remove all nodes in cluster
auto get_preserved_ops = [&cluster](const GraphNodeVec& ops) {
GraphNodeVec nodes;
for (auto* op_node : ops) {
if (cluster.find(op_node) == cluster.end()) {
nodes.emplace_back(op_node);
}
}
return nodes;
};

// removing useless link from cluster_inputs to cluster
for (auto* var_node : cluster_inputs) {
auto preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
// According to SSA form, a var node must not be any two op's output,
// and the cluster_inputs var nodes is defined as an out-graph op's
// output, so the cluster_inputs var nodes are not any subgraph op's
// output. Do not reassign input list here.
}

// removing useless link from cluster to cluster_outputs
for (auto* var_node : cluster_outputs) {
auto preserved_ops = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end());

// Note that cluster_outputs var node maybe some subgraph op's input,
// here we need remove them.
preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
}
cinn_op_desc.SetOutput("Out", output_names);
cinn_op_desc.SetAttr(kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the the cinn launch op node
AddLinkToCinnOp(cluster_inputs, cluster_outputs, cinn_op_node);
}

// Removing cluster node and internals node from Graph
void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals,
Graph* graph) {
for (auto* op_node : cluster) {
graph->RemoveNode(op_node);
}
for (auto* var_node : cluster_internals) {
graph->RemoveNode(var_node);
}
const std::unordered_set<const Node*> const_cluster{cluster.cbegin(),
cluster.cend()};
const std::unordered_set<const Node*> const_internals{
cluster_internals.cbegin(), cluster_internals.cend()};
ir::GraphSafeRemoveNodes(graph, const_cluster);
ir::GraphSafeRemoveNodes(graph, const_internals);
}

// Replacing Cinn subgraph to a special op node, whose op_type is
// Replacing Cinn subgraph to a cinn op node, whose op_type is
// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
// cluster_outputs.
// Meanwhile, move all links of cluster to the special op.
void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
const std::string& compilation_key,
Graph* graph) {
// First, add the special op node whose name is "kCinnLaunchOp" into graph
auto special_op_node = AddSpecialOpToGraph(cluster_inputs, cluster_outputs,
compilation_key, graph);
// Second, remove all graph's links which are from or to cluster nodes
RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs);
// Third, add new links from or to the the special op node
AddLinkToSpecialOp(cluster_inputs, cluster_outputs, special_op_node);
// Finally, remove the cinn sub graph from graph
// Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
const std::string& compilation_key,
Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, cluster_inputs, cluster_outputs, compilation_key,
graph);
// Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}

@@ -376,12 +357,12 @@ void SearchAllSubgraphs(Graph* graph) {
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
// Replace the found cluster to a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
cluster_outputs, cluster_internals,
compilation_key, graph);
// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
cluster_internals, compilation_key, graph);
}
}
} // namespace

void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
