Skip to content

Commit

Permalink
Optimize the subgraph generated by BuildCinnPass (#36503)
Browse files Browse the repository at this point in the history
* add feed op and new var for the generated subgraph

* perfect the test script of build_cinn_pass 

* remove useless clear and perfect some annotation
  • Loading branch information
thisjiang authored Oct 19, 2021
1 parent 7b67f39 commit 6cdc5a4
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 29 deletions.
129 changes: 115 additions & 14 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,81 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeSet = std::unordered_set<Node*>;

// Deal with subgraph's feed input var node:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : feed_vars) {
// create feed op
OpDesc desc;
desc.SetType("feed");
desc.SetOutput("Out", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);

// create new feed var node (SSAGraph)
auto var = graph->CreateVarNode(old_var->Var());

// link feed op and feed var
op->outputs = {var};
var->inputs = {op};

// link feed var to cluster op
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
}
// Do not need relink old op or old var here, they will be
// fixed in RemoveLinkFromCluster, here we just deal with
// new subgraph's node.
}
}
}

// Deal with subgraph's parameter var node:
// create a new input var node, it's data will get by scope,
// so it don't need feed op
void AddParamVar(const std::unordered_set<Node*>& param_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : param_vars) {
auto var = graph->CreateVarNode(old_var->Var());

for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
}
}
}
}

// Deal with subgraph's outputs var node:
// create a new output var node and it's fetch op
void AddOutputVar(const std::unordered_set<Node*>& output_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : output_vars) {
auto var = graph->CreateVarNode(old_var->Var());

for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) {
var->inputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->outputs.emplace_back(var);
}
}
}
}

// Create new subgraph with and op nodes are cluster nodes, and all
// var node are from internal nodes
std::unique_ptr<Graph> CreateNewSubGraph(
const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals) {
std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs) {
// Graph's constructor must has one parameter, and in our code,
// the ProgramDesc is useless, so here we pass a temporary object.
auto sub_graph = std::make_unique<Graph>(framework::ProgramDesc());
Expand All @@ -84,22 +155,45 @@ std::unique_ptr<Graph> CreateNewSubGraph(
old_var2new_var[var] = sub_node;
}

std::unordered_set<Node*> need_feed_vars;
std::unordered_set<Node *> param_vars, output_vars;
// the subgraph is independently, so here we only need link
// to the node in new subgraph, and discard the link to
// out-graph.
for (auto* op : cluster) {
for (auto* var : op->inputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
} else if (cluster_inputs.count(var)) {
if (var->Var()->IsParameter()) {
// Parameters have been preserved in scope, compared to feed var,
// param just need add new var and don't need add feed op.
// The var is used for check whether we need preserve the tensor
// when transform paddle scope to CINN scope.
param_vars.insert(var);
} else {
// When the var is subgraph input and the var is not parameter,
// we need add a new feed op to feed the var.
need_feed_vars.insert(var);
}
}
}
for (auto* var : op->outputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
} else {
// Create new output var node to guarantee the independency of
// subgraph. In other words, the subgraph has no connection with
// other graph, even the input graph.
output_vars.insert(var);
}
}
}

AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, sub_graph.get());
AddParamVar(param_vars, cluster, old_op2new_op, sub_graph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, sub_graph.get());

for (auto* var : cluster_internals) {
for (auto* op : var->inputs) {
if (cluster.count(op)) {
Expand All @@ -118,10 +212,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(

// This interface is used to classify all variables involved in a cluster into
// three types: inputs, outputs, and internals.
// Specially, the internal node is a node that only used by sub-graph, and
// The input node is some subgraph op's input but not any subgraph op's output.
// The output node is some subgraph op's output and some out-graph op's input.
// Specially, the internal node is a node that only used by subgraph, and
// out-graph should not using this node at all.
// inputs & outputs & internals == NULL
// inputs | outputs | internals == all graph node
// cluster_inputs & cluster_outputs & cluster_internals == NULL
// cluster_outputs | cluster_internals == all graph op's outputs node
void AnalyseClusterVariables(const GraphNodeSet& cluster,
GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs,
Expand Down Expand Up @@ -154,10 +250,6 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster,
}
}

// if a output node also exists in input list, remove.
for (auto* var_node : *cluster_inputs) {
cluster_outputs->erase(var_node);
}
// if a output node also exists in internal list, remove.
for (auto* var_node : *cluster_internals) {
cluster_outputs->erase(var_node);
Expand Down Expand Up @@ -206,14 +298,23 @@ void RemoveLinkFromCluster(const GraphNodeSet& cluster,

// removing useless link from cluster_inputs to cluster
for (auto* var_node : cluster_inputs) {
auto preserved_nodes = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_nodes.begin(), preserved_nodes.end());
auto preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
// According to SSA form, a var node must not be any two op's output,
// and the cluster_inputs var nodes is defined as an out-graph op's
// output, so the cluster_inputs var nodes are not any subgraph op's
// output. Do not reassign input list here.
}

// removing useless link from cluster to cluster_outputs
for (auto* var_node : cluster_outputs) {
auto preserved_nodes = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_nodes.begin(), preserved_nodes.end());
auto preserved_ops = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end());

// Note that cluster_outputs var node maybe some subgraph op's input,
// here we need remove them.
preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
}
}

Expand Down Expand Up @@ -272,7 +373,7 @@ void SearchAllSubgraphs(Graph* graph,
&cluster_internals);

cinn_subgraphs->emplace_back(
CreateNewSubGraph(cluster_set, cluster_internals));
CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs));

// replacing subgraph to a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
Expand Down
Loading

0 comments on commit 6cdc5a4

Please sign in to comment.