diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 97f486065ac62..8afd20d6a00b5 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -446,6 +446,27 @@ std::vector TopologySortGraphByDescOrder(const Graph &graph) { return ret; } +void RemoveControlDepInputAndOuput(OpDesc *op_desc) { + auto remove_control_dep_var = [](VariableNameMap *var_name_map) { + for (auto &pair : *var_name_map) { + std::vector &var_names = pair.second; + auto it = var_names.begin(); + while (it != var_names.end()) { + if (it->find(ir::Node::kControlDepVarName) != std::string::npos) { + it = var_names.erase(it); + VLOG(6) << "Remove var " << *it; + } else { + ++it; + } + } + } + }; + + remove_control_dep_var(op_desc->MutableInputs()); + remove_control_dep_var(op_desc->MutableOutputs()); + op_desc->Flush(); +} + static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { desc->SetType("fill_constant"); desc->SetAttr( @@ -552,7 +573,9 @@ static void GraphToBlock(const Graph &graph, std::vector ops; GetGraphOpDesc(nodes, &ops); + for (auto &op : ops) { + RemoveControlDepInputAndOuput(&op); block->add_ops()->MergeFrom(*op.Proto()); } }