From 2e69429ee0b3eb97a949fce2f7bf26cdbf18d441 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 5 Jun 2018 16:29:05 +0800 Subject: [PATCH 01/17] init --- .../fluid/inference/analysis/pass_manager.cc | 45 ++++++ .../fluid/inference/analysis/pass_manager.h | 136 ++++++++++++++++++ .../analysis/tensorrt_subgraph_pass.cc | 33 +++++ .../analysis/tensorrt_subgraph_pass.h | 47 ++++++ 4 files changed, 261 insertions(+) create mode 100644 paddle/fluid/inference/analysis/pass_manager.cc create mode 100644 paddle/fluid/inference/analysis/pass_manager.h create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc new file mode 100644 index 0000000000000..159e7d643a2e2 --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/pass_manager.h" +#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void PassManagerMain::RunAll(const framework::proto::ProgramDesc &desc) { + for (auto &pass : data_) { + pass->RunAll(); + } +} + +// +// CustomIterPassManager +// + +DataFlowGraphPassManager::DataFlowGraphPassManager() { + type_ = kCustomIter; + Register("fluid_to_data_flow_graph", new FluidToDataFlowGraphPass); +} + +void DataFlowGraphPassManager::RunAll() { + for (auto &pass : data_) { + pass->Run(graph_); + } +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h new file mode 100644 index 0000000000000..8f7de7b95f144 --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file defines the interface for pass management. + */ + +#pragma once + +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +class PassManager; + +/* + * PassManagerMain - Executes all the PassManagers. + */ +class PassManagerMain : public OrderedRegistry { + public: + static PassManagerMain &Global() { + static auto *x = new PassManagerMain; + return *x; + } + + // Execute all the PassManagers registered. + void RunAll(const framework::proto::ProgramDesc &desc); + + PADDLE_DISALLOW_COPY_AND_ASSIGN(PassManagerMain) + + protected: + DataFlowGraph data_flow_graph_; + + private: + PassManagerMain() = default; +}; + +/* + * PassManager is the base class for all pass managers, a pass manager has + * several Pass-es registered, and execute them in the right order. + */ +class PassManager : public OrderedRegistry { + public: + enum Type { + kUnknown = -1, + // The outer iteration is DFS algorithm. + kDFS_PM, + // The outer iteratoin is BFS algorithm. + kBFS_PM, + // The outer iteration follows a customized order. + kCustomIter + }; + + // Call all the passes' Initialize methods. The desc and data_flow_graph are + // globally shared, so pass them as the arguemnts for all the pass managers. + virtual bool Initialize(const framework::proto::ProgramDesc &desc, + DataFlowGraph *data_flow_graph) = 0; + + // Run all the passes. + virtual void RunAll() = 0; + + // Call all the passes' Finalize methods. + virtual bool Finalize() = 0; + + virtual ~PassManager() {} + + protected: + Type type_{Type::kUnknown}; +}; + +// A pass manager that traverse the graph in DFS order. +template +class DFSPassManager : public PassManager { + public: + DFSPassManager(const GraphType &graph); + + bool Initialize(const framework::proto::ProgramDesc &desc, + DataFlowGraph *data_flow_graph) override; + bool Finalize() override; + // DFS traverse the graph, call the passes in each step. + void RunAll() override; + + private: + GraphType graph_; +}; + +// TODO(Superjomn) Implement BFSPassManager if needed. + +/* + * A pass manager that traverse the graph in a customized order, it is a virtual + * class and need to be override by sub-classes. + */ +class DataFlowGraphPassManager : public PassManager { + public: + DataFlowGraphPassManager(); + bool Initialize(const framework::proto::ProgramDesc &desc, + DataFlowGraph *data_flow_graph) override { + graph_ = data_flow_graph; + for (auto &pass : data_) { + pass->Initialize(); + pass->Initialize(desc); + } + return true; + } + + void RunAll() override; + + bool Finalize() override { + for (auto &pass : data_) { + pass->Finalize(); + } + return true; + } + + private: + DataFlowGraph *graph_; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc new file mode 100644 index 0000000000000..b75df33b71311 --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" +#include "paddle/fluid/inference/analysis/subgraph_splitter.h" + +namespace paddle { +namespace inference { +namespace analysis { + +TensorRTSubGraphPass::TensorRTSubGraphPass( + const TensorRTSubGraphPass::NodeInsideSubgraphTeller &teller) + : node_inside_subgraph_teller_(teller) {} + +void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { + SubGraphFuse(graph, node_inside_subgraph_teller_); +} + +} // analysis +} // inference + +} // paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h new file mode 100644 index 0000000000000..328ef376c152e --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/inference/analysis/node.h" +#include "paddle/fluid/inference/analysis/pass.h" +#include "paddle/fluid/inference/analysis/subgraph_splitter.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * Parse the graph and replace TensorRT supported nodes with SubGraphNode + */ +class TensorRTSubGraphPass : public DataFlowGraphPass { + public: + // Tell whether to transform a sub-graph into TensorRT. + using NodeInsideSubgraphTeller = SubGraphFuse::NodeInsideSubgraphTeller; + + TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller); + + bool Initialize() override { return true; } + + // This class get a sub-graph as input and determine whether to transform this + // sub-graph into TensorRT. + void Run(DataFlowGraph* graph) override; + + private: + NodeInsideSubgraphTeller node_inside_subgraph_teller_; +}; + +} // namespace analysis +} // namespace inference +} // paddle From 1292dbc44674083d1c92be0d16f196b764f0659e Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 6 Jun 2018 08:14:49 +0800 Subject: [PATCH 02/17] init --- paddle/fluid/inference/analysis/CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 9faf5bb303677..8516f69f2ee43 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -15,3 +15,9 @@ cc_test(test_subgraph_splitter DEPS analysis paddle_fluid tensor ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) + +cc_test(test_dfg_graphviz_draw_pass + SRCS dfg_graphviz_draw_pass_tester.cc + DEPS analysis + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) +set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) From 9776b3ae837785d845e8c5ddf842b2e48b0a9472 Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 6 Jun 2018 08:40:47 +0800 Subject: [PATCH 03/17] add dfg graphviz pass --- .../fluid/inference/analysis/CMakeLists.txt | 6 ++ .../analysis/dfg_graphviz_draw_pass.h | 68 +++++++++++++++++++ .../analysis/dfg_graphviz_draw_pass_tester.cc | 46 +++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h create mode 100644 paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 9faf5bb303677..8516f69f2ee43 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -15,3 +15,9 @@ cc_test(test_subgraph_splitter DEPS analysis paddle_fluid tensor ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) + +cc_test(test_dfg_graphviz_draw_pass + SRCS dfg_graphviz_draw_pass_tester.cc + DEPS analysis + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) +set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h new file mode 100644 index 0000000000000..41d4475382bef --- /dev/null +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file create an DFG_GraphvizDrawPass which helps to draw a data flow + * graph's structure using graphviz. + */ + +#pragma once + +#include +#include +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * Output a dot file and write to some place. + */ +class DFG_GraphvizDrawPass : public DataFlowGraphPass { + public: + DFG_GraphvizDrawPass(const std::string& dir, const std::string& id) + : dir_(dir), id_(id) {} + + bool Initialize() override { return Pass::Initialize(); } + void Run(DataFlowGraph* graph) override { + auto content = Draw(graph); + std::ofstream file(GenDotPath()); + file.write(content.c_str(), content.size()); + file.close(); + LOG(INFO) << "draw dot to " << GenDotPath(); + } + + bool Finalize() override { return Pass::Finalize(); } + + Pass* CreatePrinterPass(std::ostream& os, + const std::string& banner) const override { + return nullptr; + } + + private: + // Path of the dot file to output. + std::string GenDotPath() const { + return dir_ + "/" + "graph_" + id_ + ".dot"; + } + + std::string Draw(DataFlowGraph* graph) { return graph->DotString(); } + + std::string dir_; + std::string id_; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc new file mode 100644 index 0000000000000..3fc1cc18b8554 --- /dev/null +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" + +#include +#include +#include +#include "paddle/fluid/inference/analysis/ut_helper.h" + +namespace paddle { +namespace inference { +namespace analysis { + +TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { + auto dfg = ProgramDescToDFG(desc); + DFG_GraphvizDrawPass pass("./", "test"); + pass.Initialize(); + pass.Run(&dfg); + + // test content + std::ifstream file("./graph_test.dot"); + ASSERT_TRUE(file.is_open()); + + std::string line; + int no{0}; + while (std::getline(file, line)) { + no++; + } + ASSERT_EQ(no, 82); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle From 4ff4ca05cbeed72267fec3557c321bf63c1ba5de Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 6 Jun 2018 09:00:39 +0800 Subject: [PATCH 04/17] fix --- paddle/fluid/inference/analysis/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 8516f69f2ee43..50835784440bf 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -20,4 +20,4 @@ cc_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc DEPS analysis ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) +set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) From c38690942d8d45567b83c1a0f73a9a51c666302e Mon Sep 17 00:00:00 2001 From: superjomn Date: Fri, 8 Jun 2018 08:16:51 +0800 Subject: [PATCH 05/17] add fusion --- .../fluid/inference/analysis/CMakeLists.txt | 10 ++++- paddle/fluid/inference/analysis/node.cc | 3 ++ paddle/fluid/inference/analysis/pass.h | 4 +- .../analysis/subgraph_splitter_tester.cc | 45 ++++++++++++++----- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 50835784440bf..8792b4cc2369d 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -1,5 +1,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) -cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc +cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc + fluid_to_data_flow_graph_pass.cc + tensorrt_subgraph_pass.cc DEPS paddle_fluid) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) @@ -21,3 +23,9 @@ cc_test(test_dfg_graphviz_draw_pass DEPS analysis ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) + +cc_test(test_tensorrt_subgraph_pass + SRCS tensorrt_subgraph_pass_tester.cc + DEPS analysis paddle_fluid tensor + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) +set_tests_properties(test_tensorrt_subgraph_pass PROPERTIES DEPENDS test_word2vec) diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc index fe060526080b1..37063278112ad 100644 --- a/paddle/fluid/inference/analysis/node.cc +++ b/paddle/fluid/inference/analysis/node.cc @@ -40,6 +40,9 @@ Node *NodeMap::Create(Node::Type type) { case Node::Type::kValue: nodes_.emplace_back(new Value); break; + case Node::Type ::kFunctionBlock: + nodes_.emplace_back(new FunctionBlock); + break; default: PADDLE_THROW("Not supported node type."); } diff --git a/paddle/fluid/inference/analysis/pass.h b/paddle/fluid/inference/analysis/pass.h index aa0e8667b5e4a..89c81a70ce863 100644 --- a/paddle/fluid/inference/analysis/pass.h +++ b/paddle/fluid/inference/analysis/pass.h @@ -50,7 +50,9 @@ class Pass { // Get a Pass appropriate to print the Node this pass operates on. virtual Pass *CreatePrinterPass(std::ostream &os, - const std::string &banner) const = 0; + const std::string &banner) const { + return nullptr; + } // Run on a single Node. virtual void Run(Node *x) { LOG(FATAL) << "not valid"; } diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc index 0644c0db12e3d..9ca5db7debb67 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc @@ -19,22 +19,23 @@ namespace paddle { namespace inference { namespace analysis { +SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { + if (node->type() != Node::Type::kFunction) return false; + const auto* func = static_cast(node); + if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || + func->func_type() == "conv2d" || func->func_type() == "mul" || + func->func_type() == "sigmoid" || func->func_type() == "softmax") { + LOG(INFO) << "sub-graph marked " << node->repr(); + return true; + } + return false; +}; + TEST_F(DFG_Tester, Split) { auto desc = LoadProgramDesc(); auto dfg = ProgramDescToDFG(desc); LOG(INFO) << "spliter\n" << dfg.DotString(); - SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { - if (node->type() != Node::Type::kFunction) return false; - const auto* func = static_cast(node); - if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || - func->func_type() == "conv2d" || func->func_type() == "mul" || - func->func_type() == "sigmoid" || func->func_type() == "softmax") { - LOG(INFO) << "sub-graph marked " << node->repr(); - return true; - } - return false; - }; ASSERT_GT(dfg.nodes.size(), 5UL); auto subgraphs = SubGraphSplitter(&dfg, teller)(); @@ -62,6 +63,28 @@ TEST_F(DFG_Tester, Split) { ASSERT_EQ(subgraphs.back().size(), 6UL); } +TEST_F(DFG_Tester, Fuse) { + auto desc = LoadProgramDesc(); + auto dfg = ProgramDescToDFG(desc); + + size_t count0 = dfg.nodes.size(); + + SubGraphFuse fuse(&dfg, teller); + fuse(); + + int count1=0; + for (auto& node : dfg.nodes.nodes()) { + if (node->deleted()) { + LOG(INFO) << "deleted " << node->repr(); + } + count1 += node->deleted(); + } + + // At least one nodes should be deleted. + ASSERT_EQ(dfg.nodes.size(), count0+1); // added a new FunctionBlock + ASSERT_EQ(6UL, count1); +} + } // namespace analysis } // namespace inference } // namespace paddle From 30a3ebc67bf0d89f80bc320b5ac3275958628a42 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 12 Jun 2018 17:19:22 +0800 Subject: [PATCH 06/17] update --- .../fluid/inference/analysis/CMakeLists.txt | 1 + .../inference/analysis/data_flow_graph.cc | 15 +----- .../analysis/dfg_graphviz_draw_pass.cc | 54 +++++++++++++++++++ .../analysis/dfg_graphviz_draw_pass.h | 38 ++++++------- .../analysis/fluid_to_data_flow_graph_pass.cc | 4 +- paddle/fluid/inference/analysis/node.h | 22 ++++---- .../fluid/inference/analysis/pass_manager.cc | 4 +- .../fluid/inference/analysis/pass_manager.h | 8 ++- 8 files changed, 98 insertions(+), 48 deletions(-) create mode 100644 paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 8792b4cc2369d..1baf94ba24fd2 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -2,6 +2,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc tensorrt_subgraph_pass.cc + dfg_graphviz_draw_pass.cc DEPS paddle_fluid) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 4220451e3caee..273a8f54da6dd 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/analysis/data_flow_graph.h" +#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/dot.h" namespace paddle { @@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const { // Add nodes for (size_t i = 0; i < nodes.size(); i++) { const Node &node = nodes.Get(i); - switch (node.type()) { - case Node::Type::kValue: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - case Node::Type::kFunction: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - case Node::Type::kFunctionBlock: - dot.AddNode(node.repr(), node.dot_attrs()); - break; - default: - PADDLE_THROW("unsupported Node type %d", static_cast(node.type())); - } + dot.AddNode(node.repr(), node.dot_attrs()); } // Add edges diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc new file mode 100644 index 0000000000000..afffb3feb0c51 --- /dev/null +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { + auto content = Draw(graph); + std::ofstream file(GenDotPath()); + file.write(content.c_str(), content.size()); + file.close(); + LOG(INFO) << "draw dot to " << GenDotPath(); +} + +std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { + Dot dot; + // Add nodes + for (size_t i = 0; i < graph->nodes.size(); i++) { + const Node &node = graph->nodes.Get(i); + if (config_.display_deleted_node || !node.deleted()) { + dot.AddNode(node.repr(), node.dot_attrs()); + } + } + // Add edges + for (size_t i = 0; i < graph->nodes.size(); i++) { + const Node &node = graph->nodes.Get(i); + if (!config_.display_deleted_node && node.deleted()) continue; + for (auto &in : node.inlinks) { + if (!config_.display_deleted_node && in->deleted()) continue; + for (auto &in : node.inlinks) { + dot.AddEdge(in->repr(), node.repr(), {}); + } + } + } + return dot.Build(); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h index 41d4475382bef..a8d66852c038a 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/pass.h" namespace paddle { @@ -32,35 +33,34 @@ namespace analysis { */ class DFG_GraphvizDrawPass : public DataFlowGraphPass { public: - DFG_GraphvizDrawPass(const std::string& dir, const std::string& id) - : dir_(dir), id_(id) {} + struct Config { + Config(const std::string &dir, const std::string &id, + bool display_deleted_node = false) + : dir(dir), id(id), display_deleted_node(display_deleted_node) {} - bool Initialize() override { return Pass::Initialize(); } - void Run(DataFlowGraph* graph) override { - auto content = Draw(graph); - std::ofstream file(GenDotPath()); - file.write(content.c_str(), content.size()); - file.close(); - LOG(INFO) << "draw dot to " << GenDotPath(); - } + // The directory to store the .dot or .png files. + const std::string dir; + // The identifier for this dot file. + const std::string id; + // Whether to display deleted nodes, default false. + const bool display_deleted_node; + }; - bool Finalize() override { return Pass::Finalize(); } + DFG_GraphvizDrawPass(const Config &config) : config_(config) {} - Pass* CreatePrinterPass(std::ostream& os, - const std::string& banner) const override { - return nullptr; - } + bool Initialize() override { return Pass::Initialize(); } + void Run(DataFlowGraph *graph) override; + bool Finalize() override { return Pass::Finalize(); } private: // Path of the dot file to output. std::string GenDotPath() const { - return dir_ + "/" + "graph_" + id_ + ".dot"; + return config_.dir + "/" + "graph_" + config_.id + ".dot"; } - std::string Draw(DataFlowGraph* graph) { return graph->DotString(); } + std::string Draw(DataFlowGraph *graph); - std::string dir_; - std::string id_; + Config config_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 9f67c989cca4a..2f022dd2aeea4 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -41,7 +41,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { const auto &var = main_block.vars(i); auto *v = graph->nodes.Create(Node::Type::kValue); v->SetName(var.name()); - v->SetExtraInfo(const_cast(static_cast(&var))); + v->SetPbDesc(const_cast(static_cast(&var))); var2id[var.name()] = v->id(); } for (int i = 0; i < main_block.ops_size(); i++) { @@ -51,7 +51,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { static_cast(o)->SetFuncType(op.type()); // Link to the original protobuf message's memory, make it easier to // generate from a data flow graph to fluid ProgramDesc. - o->SetExtraInfo(const_cast(static_cast(&op))); + o->SetPbDesc(const_cast(static_cast(&op))); // set inputs and outputs // TODO(Superjomn) make sure the InputNames is the real variable name. for (int j = 0; j < op.inputs_size(); j++) { diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h index 7972ca25c9218..fa0afe4a6bbfe 100644 --- a/paddle/fluid/inference/analysis/node.h +++ b/paddle/fluid/inference/analysis/node.h @@ -71,12 +71,17 @@ class Node { // Get an additional attribute and convert it to T data type. NOTE this will // silently create a new attribute if not exists. - Attr &attr(const std::string &name) { return attrs_[name]; } + Attr &attr(const std::string &name) const { return attrs_[name]; } int id() const { return id_; } - bool deleted() const { return deleted_; } + // The Protobuf description is set/get with a void* to decouple Node interface + // from a specific kind of Protobuf message. + void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; } + void *pb_desc() const { return attr("pb_desc").Pointer(); } + void SetDeleted() { deleted_ = true; } + bool deleted() const { return deleted_; } void SetName(const std::string &name) { name_ = name; } const std::string &name() const { return name_; } @@ -84,29 +89,25 @@ class Node { void SetType(Type type) { type_ = type; } Type type() const { return type_; } - void *extra_info() const { return extra_info_; } - void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; } - // Input links. std::vector inlinks; // Output links. std::vector outlinks; // A helper class to maintain the status from Pass. - // TODO(superjomn) add a checker here to ensure the T is primary. struct Attr { // NOTE T should be a primary type or a struct combined by several primary // types. // NOTE the STL containers should not use here. // Some usages - // Attr attr; - // T data; - // attr.data.assign((char*)data, sizeof(data)); + // Attr attr; + // attr.Bool() = true; bool &Bool() { return As(); } float &Float() { return As(); } int32_t &Int32() { return As(); } int64_t &Int64() { return As(); } + void *&Pointer() { return As(); } private: template @@ -130,6 +131,7 @@ class Node { size_t type_hash_{std::numeric_limits::max()}; }; + // Type checks. bool IsFunction() const { return type_ == Node::Type::kFunction; } bool IsValue() const { return type_ == Node::Type::kValue; } bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; } @@ -149,7 +151,7 @@ class Node { // Mark this node is deleted by some pass. bool deleted_{false}; - void *extra_info_; + // void *extra_info_; mutable std::unordered_map attrs_; }; diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc index 159e7d643a2e2..a737eb8657218 100644 --- a/paddle/fluid/inference/analysis/pass_manager.cc +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -29,12 +29,12 @@ void PassManagerMain::RunAll(const framework::proto::ProgramDesc &desc) { // CustomIterPassManager // -DataFlowGraphPassManager::DataFlowGraphPassManager() { +DFG_PassManager::DFG_PassManager() { type_ = kCustomIter; Register("fluid_to_data_flow_graph", new FluidToDataFlowGraphPass); } -void DataFlowGraphPassManager::RunAll() { +void DFG_PassManager::RunAll() { for (auto &pass : data_) { pass->Run(graph_); } diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h index 8f7de7b95f144..567e27f6b46a0 100644 --- a/paddle/fluid/inference/analysis/pass_manager.h +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -105,9 +105,9 @@ class DFSPassManager : public PassManager { * A pass manager that traverse the graph in a customized order, it is a virtual * class and need to be override by sub-classes. */ -class DataFlowGraphPassManager : public PassManager { +class DFG_PassManager : public PassManager { public: - DataFlowGraphPassManager(); + DFG_PassManager(); bool Initialize(const framework::proto::ProgramDesc &desc, DataFlowGraph *data_flow_graph) override { graph_ = data_flow_graph; @@ -131,6 +131,10 @@ class DataFlowGraphPassManager : public PassManager { DataFlowGraph *graph_; }; + +// Run all the pass managers to analysis and optimize the graph. +static void RunAnalysis() {} + } // namespace analysis } // namespace inference } // namespace paddle From 9daeed27b98caccbc6d88bc39e5df690e59e88d6 Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 13 Jun 2018 09:39:26 +0800 Subject: [PATCH 07/17] add tester --- .../analysis/tensorrt_subgraph_pass_tester.cc | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc new file mode 100644 index 0000000000000..9d3c790746b15 --- /dev/null +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h" + +#include +#include +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" +#include "paddle/fluid/inference/analysis/ut_helper.h" + +namespace paddle { +namespace inference { +namespace analysis { + +DEFINE_string(model_dir, "", "inference test model dir"); + +TEST(TensorRTSubGraph, single_pass) { + auto desc = LoadProgramDesc(); + auto dfg = ProgramDescToDFG(desc); + + SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) { + if (node->type() != Node::Type::kFunction) return false; + const auto* func = static_cast(node); + if (func->func_type() == "elementwise_add" || func->func_type() == "relu" || + func->func_type() == "conv2d" || func->func_type() == "mul" || + func->func_type() == "sigmoid" || func->func_type() == "softmax") { + LOG(INFO) << "sub-graph marked " << node->repr(); + return true; + } + return false; + }; + + DFG_GraphvizDrawPass::Config config{"./", "test"}; + DFG_GraphvizDrawPass dfg_pass(config); + dfg_pass.Initialize(); + + DFG_GraphvizDrawPass dfg_pass1(config); + dfg_pass1.Initialize(); + + dfg_pass.Run(&dfg); + + TensorRTSubGraphPass trt_pass(std::move(teller)); + trt_pass.Initialize(); + + trt_pass.Run(&dfg); + + dfg_pass1.Run(&dfg); + + // Check the TRT op's block desc + for (auto node : dfg.nodes.nodes()) { + if (node->IsFunctionBlock() ) { + auto& desc = + } + } +} + +TEST(TensorRTSubGraph, pass_manager) {} + +} // namespace analysis +} // namespace inference +} // namespace paddle From 743f1f80c485e157420e0414806eaa010f1ca784 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sat, 16 Jun 2018 07:03:55 +0000 Subject: [PATCH 08/17] update --- paddle/fluid/inference/analysis/argument.cc | 1 + paddle/fluid/inference/analysis/argument.h | 32 +++++++ .../fluid/inference/analysis/pass_manager.h | 88 +++++-------------- .../inference/analysis/pass_manager_tester.cc | 29 ++++++ .../analysis/tensorrt_subgraph_pass_tester.cc | 1 - 5 files changed, 85 insertions(+), 66 deletions(-) create mode 100644 paddle/fluid/inference/analysis/argument.cc create mode 100644 paddle/fluid/inference/analysis/argument.h create mode 100644 paddle/fluid/inference/analysis/pass_manager_tester.cc diff --git a/paddle/fluid/inference/analysis/argument.cc b/paddle/fluid/inference/analysis/argument.cc new file mode 100644 index 0000000000000..5f931a3941907 --- /dev/null +++ b/paddle/fluid/inference/analysis/argument.cc @@ -0,0 +1 @@ +#include "paddle/fluid/inference/analysis/argument.h" diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h new file mode 100644 index 0000000000000..b36d8786d6fc0 --- /dev/null +++ b/paddle/fluid/inference/analysis/argument.h @@ -0,0 +1,32 @@ +/* + * This file defines the class Argument, which is the input and output of the + * analysis module. All the fields that needed either by Passes or PassManagers + * are contained in Argument. + * + * TODO(Superjomn) Find some way better to contain the fields when it grow too + * big. + */ + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/data_flow_graph.h" + +namespace paddle { +namespace inference { +namespace analysis { + +/* + * The argument definition of both Pass and PassManagers. + * + * All the fields should be registered here for clearness. + */ +struct Argument { + // The graph that process by the Passes or PassManagers. + std::unique_ptr main_dfg; + + // The original program desc. + std::unique_ptr origin_program_desc; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h index 567e27f6b46a0..0617504b7e796 100644 --- a/paddle/fluid/inference/analysis/pass_manager.h +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -13,7 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. */ /* - * This file defines the interface for pass management. + * This file defines the logic of pass management. The analysis for inference is + * a pipeline of Passes, a PassManager is a agency that helps to manage the + * executation of the Passes. + * + * There are two modes of Passes, the first one is called NodePass and takes + * an Node as input and output; the second one is called DFGPass and takes a + * DFG(Data Flow Graph) as input and output. It is hard to put all the passes in + * the same pipeline, there are two kinds of PassManagers, both takes a DFG as + * input and output a DFG, but the Passes inside are different: + * + * 1. NodePassManager: the passes inside are all NodePasses, it can have + * different graph trivial algorithm, for example, DFS_NodePassManager will + * trigger the passes in depth first order; + * 2. DfgPassManager: the passes inside are all DfgPasses. */ #pragma once @@ -26,46 +39,12 @@ namespace paddle { namespace inference { namespace analysis { -class PassManager; - -/* - * PassManagerMain - Executes all the PassManagers. - */ -class PassManagerMain : public OrderedRegistry { - public: - static PassManagerMain &Global() { - static auto *x = new PassManagerMain; - return *x; - } - - // Execute all the PassManagers registered. - void RunAll(const framework::proto::ProgramDesc &desc); - - PADDLE_DISALLOW_COPY_AND_ASSIGN(PassManagerMain) - - protected: - DataFlowGraph data_flow_graph_; - - private: - PassManagerMain() = default; -}; - /* * PassManager is the base class for all pass managers, a pass manager has - * several Pass-es registered, and execute them in the right order. + * several Pass-es registered, and execute them in the linear order. */ class PassManager : public OrderedRegistry { public: - enum Type { - kUnknown = -1, - // The outer iteration is DFS algorithm. - kDFS_PM, - // The outer iteratoin is BFS algorithm. - kBFS_PM, - // The outer iteration follows a customized order. - kCustomIter - }; - // Call all the passes' Initialize methods. The desc and data_flow_graph are // globally shared, so pass them as the arguemnts for all the pass managers. virtual bool Initialize(const framework::proto::ProgramDesc &desc, @@ -77,37 +56,20 @@ class PassManager : public OrderedRegistry { // Call all the passes' Finalize methods. virtual bool Finalize() = 0; - virtual ~PassManager() {} - - protected: - Type type_{Type::kUnknown}; -}; - -// A pass manager that traverse the graph in DFS order. -template -class DFSPassManager : public PassManager { - public: - DFSPassManager(const GraphType &graph); - - bool Initialize(const framework::proto::ProgramDesc &desc, - DataFlowGraph *data_flow_graph) override; - bool Finalize() override; - // DFS traverse the graph, call the passes in each step. - void RunAll() override; + // Short identifier. + virtual std::string repr() const = 0; + // Long description. + virtual std::string description() const = 0; - private: - GraphType graph_; + virtual ~PassManager() = default; }; -// TODO(Superjomn) Implement BFSPassManager if needed. - /* - * A pass manager that traverse the graph in a customized order, it is a virtual - * class and need to be override by sub-classes. + * A pass manager that process a DFG. */ -class DFG_PassManager : public PassManager { +class DfgPassManager : public PassManager { public: - DFG_PassManager(); + DfgPassManager(); bool Initialize(const framework::proto::ProgramDesc &desc, DataFlowGraph *data_flow_graph) override { graph_ = data_flow_graph; @@ -131,10 +93,6 @@ class DFG_PassManager : public PassManager { DataFlowGraph *graph_; }; - -// Run all the pass managers to analysis and optimize the graph. -static void RunAnalysis() {} - } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager_tester.cc b/paddle/fluid/inference/analysis/pass_manager_tester.cc new file mode 100644 index 0000000000000..cdb4edfbe845c --- /dev/null +++ b/paddle/fluid/inference/analysis/pass_manager_tester.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/analysis/pass_manager.h" + +#include + +namespace paddle { +namespace inference { +namespace analysis { + + + + + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc index 9d3c790746b15..5251c86d6dbfc 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc @@ -60,7 +60,6 @@ TEST(TensorRTSubGraph, single_pass) { // Check the TRT op's block desc for (auto node : dfg.nodes.nodes()) { if (node->IsFunctionBlock() ) { - auto& desc = } } } From 8db096447ec8fc94dafa9b4fdebb97827519b039 Mon Sep 17 00:00:00 2001 From: superjomn Date: Sat, 16 Jun 2018 15:08:40 +0800 Subject: [PATCH 09/17] add data_flow_graph_to_fluid_pass --- .../analysis/data_flow_graph_to_fluid_pass.cc | 75 +++++++++++++++++++ .../analysis/data_flow_graph_to_fluid_pass.h | 54 +++++++++++++ 2 files changed, 129 insertions(+) create mode 100644 paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc create mode 100644 paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc new file mode 100644 index 0000000000000..7eed743dab2e8 --- /dev/null +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" +#include "paddle/fluid/framework/proto_desc.h" + +namespace paddle { +namespace inference { +namespace analysis { + +bool DataFlowGraphToFluidPass::Initialize(framework::proto::ProgramDesc* desc) { + desc_ = desc; + // Here some logic from program_desc.cc and will not add new interfaces into + // framework::ProgramDesc class, use some UT to assure the correctness. + auto* block = desc_->mutable_blocks()->Add(); + block->set_idx(framework::kRootBlockIndex); + block->set_parent_idx(framework::kNoneBlockIndex); + return true; +} + +bool DataFlowGraphToFluidPass::Finalize() { return true; } + +void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { + auto traits = GraphTraits(graph); + for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) { + if (it->deleted()) continue; + switch (it->type()) { + case Node::Type::kFunction: + LOG(INFO) << "add function " << it->name(); + AddFluidOp(&(*it)); + break; + case Node::Type::kFunctionBlock: + AddEngineOp(&(*it)); + break; + default: + continue; + } + } +} + +void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { + LOG(INFO) << "processing func " << node->name(); + auto* ori_op = static_cast(node->extra_info()); + // currently only the main block is analyzed. + auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); + auto* op = main_block->add_ops(); + LOG(INFO) << "to copy the op"; + *op = *ori_op; // copy the attributes, by default, these will not be changed + // by analysis phrase. + // The inputs and outputs of the existing ops are not changed by tensorrt + // subgraph pass. + // NOTE It might be changed by other passes in the long run. +} + +void DataFlowGraphToFluidPass::AddEngineOp(Node* node) { + // auto* ori_op = static_cast(node->extra_info()); + // auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); + // auto* op = main_block->add_ops(); + // TODO(Superjomn) Here need to expose some arguments for default setting. +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h new file mode 100644 index 0000000000000..4ff8712bb3958 --- /dev/null +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +/* + * This file implements the transformation from fluid ProgramDesc to data flow + * graph. + */ + +#pragma once + +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/analysis/data_flow_graph.h" +#include "paddle/fluid/inference/analysis/pass.h" + +namespace paddle { +namespace inference { +namespace analysis { +class DataFlowGraphToFluidPass final : public DataFlowGraphPass { + public: + DataFlowGraphToFluidPass() = default; + + bool Initialize(framework::proto::ProgramDesc *desc) override; + bool Finalize() override; + + void Run(DataFlowGraph *graph) override; + + Pass *CreatePrinterPass(std::ostream &os, + const std::string &banner) const override { + return nullptr; + } + + protected: + // Add a Fluid Op into the ProgramDesc. + void AddFluidOp(Node *node); + // Add a EngineOp into the ProgramDesc. + void AddEngineOp(Node *node); + + private: + framework::proto::ProgramDesc *desc_; +}; +} // namespace analysis +} // namespace inference +} // namespace paddle From 1ab4dbb554604bb2791d2c7e30c24a3cee43698e Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sat, 16 Jun 2018 13:59:31 +0000 Subject: [PATCH 10/17] update --- .../fluid/inference/analysis/CMakeLists.txt | 37 +++++++++---------- paddle/fluid/inference/analysis/argument.cc | 14 +++++++ paddle/fluid/inference/analysis/argument.h | 20 ++++++++++ .../inference/analysis/data_flow_graph.cc | 2 +- .../analysis/data_flow_graph_to_fluid_pass.cc | 8 ++-- .../analysis/data_flow_graph_to_fluid_pass.h | 2 +- .../data_flow_graph_to_fluid_pass_tester.cc | 5 +-- .../analysis/dfg_graphviz_draw_pass.h | 1 - .../analysis/dfg_graphviz_draw_pass_tester.cc | 7 ++-- .../analysis/fluid_to_data_flow_graph_pass.cc | 15 ++++---- .../analysis/fluid_to_data_flow_graph_pass.h | 6 +-- .../fluid_to_data_flow_graph_pass_tester.cc | 6 +-- paddle/fluid/inference/analysis/helper.h | 1 + paddle/fluid/inference/analysis/node.cc | 2 +- paddle/fluid/inference/analysis/pass.h | 16 +++++--- .../fluid/inference/analysis/pass_manager.cc | 25 +------------ .../fluid/inference/analysis/pass_manager.h | 17 ++++----- .../inference/analysis/pass_manager_tester.cc | 8 +--- .../analysis/subgraph_splitter_tester.cc | 4 +- .../analysis/tensorrt_subgraph_pass.h | 2 +- .../analysis/tensorrt_subgraph_pass_tester.cc | 2 +- paddle/fluid/inference/analysis/ut_helper.h | 11 ++++-- 22 files changed, 113 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 1baf94ba24fd2..daebeeca60463 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -1,6 +1,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc + data_flow_graph_to_fluid_pass.cc tensorrt_subgraph_pass.cc dfg_graphviz_draw_pass.cc DEPS paddle_fluid) @@ -9,24 +10,22 @@ cc_test(test_dot SRCS dot_tester.cc DEPS analysis) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) -cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec) +function (inference_analysis_test TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS) + cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) -cc_test(test_subgraph_splitter - SRCS subgraph_splitter_tester.cc - DEPS analysis paddle_fluid tensor - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) + cc_test(${TARGET} + SRCS "${analysis_test_SRCS}" + DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5) + set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) +endfunction(inference_analysis_test) -cc_test(test_dfg_graphviz_draw_pass - SRCS dfg_graphviz_draw_pass_tester.cc - DEPS analysis - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) - -cc_test(test_tensorrt_subgraph_pass - SRCS tensorrt_subgraph_pass_tester.cc - DEPS analysis paddle_fluid tensor - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) -set_tests_properties(test_tensorrt_subgraph_pass PROPERTIES DEPENDS test_word2vec) +inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) +inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) +inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) +inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) +inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) +inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) diff --git a/paddle/fluid/inference/analysis/argument.cc b/paddle/fluid/inference/analysis/argument.cc index 5f931a3941907..cb0263d5d98e8 100644 --- a/paddle/fluid/inference/analysis/argument.cc +++ b/paddle/fluid/inference/analysis/argument.cc @@ -1 +1,15 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/fluid/inference/analysis/argument.h" diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index b36d8786d6fc0..4b46eef76921a 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -1,3 +1,17 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* * This file defines the class Argument, which is the input and output of the * analysis module. All the fields that needed either by Passes or PassManagers @@ -27,6 +41,12 @@ struct Argument { std::unique_ptr origin_program_desc; }; +#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ + if (!(field__)) { \ + LOG(ERROR) << "field " << #field__ << " should be set."; \ + return false; \ + } + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 273a8f54da6dd..c30a7c26cecbe 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/analysis/data_flow_graph.h" -#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/inference/analysis/node.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 7eed743dab2e8..f7d4cca2132d1 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -19,8 +19,10 @@ namespace paddle { namespace inference { namespace analysis { -bool DataFlowGraphToFluidPass::Initialize(framework::proto::ProgramDesc* desc) { - desc_ = desc; +bool DataFlowGraphToFluidPass::Initialize(Argument* argument) { + ANALYSIS_ARGUMENT_CHECK_FIELD(argument) + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) + desc_ = argument->origin_program_desc.get(); // Here some logic from program_desc.cc and will not add new interfaces into // framework::ProgramDesc class, use some UT to assure the correctness. auto* block = desc_->mutable_blocks()->Add(); @@ -51,7 +53,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { LOG(INFO) << "processing func " << node->name(); - auto* ori_op = static_cast(node->extra_info()); + auto* ori_op = static_cast(node->pb_desc()); // currently only the main block is analyzed. auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto* op = main_block->add_ops(); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h index 4ff8712bb3958..612763b7dda07 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -30,7 +30,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass { public: DataFlowGraphToFluidPass() = default; - bool Initialize(framework::proto::ProgramDesc *desc) override; + bool Initialize(Argument *argument) override; bool Finalize() override; void Run(DataFlowGraph *graph) override; diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc index dcee75cee50ed..d8fc5e580a98f 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc @@ -27,13 +27,12 @@ namespace inference { namespace analysis { TEST_F(DFG_Tester, Test) { - framework::proto::ProgramDesc new_desc; DataFlowGraph graph; FluidToDataFlowGraphPass pass0; DataFlowGraphToFluidPass pass1; - pass0.Initialize(desc); - pass1.Initialize(&new_desc); + ASSERT_TRUE(pass0.Initialize(&argument)); + ASSERT_TRUE(pass1.Initialize(&argument)); pass0.Run(&graph); pass1.Run(&graph); diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h index a8d66852c038a..60773571dfc40 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h @@ -48,7 +48,6 @@ class DFG_GraphvizDrawPass : public DataFlowGraphPass { DFG_GraphvizDrawPass(const Config &config) : config_(config) {} - bool Initialize() override { return Pass::Initialize(); } void Run(DataFlowGraph *graph) override; bool Finalize() override { return Pass::Finalize(); } diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc index 3fc1cc18b8554..0ce35cb974676 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc @@ -24,9 +24,10 @@ namespace inference { namespace analysis { TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { - auto dfg = ProgramDescToDFG(desc); - DFG_GraphvizDrawPass pass("./", "test"); - pass.Initialize(); + auto dfg = ProgramDescToDFG(*argument.origin_program_desc); + DFG_GraphvizDrawPass::Config config("./", "test"); + DFG_GraphvizDrawPass pass(config); + pass.Initialize(&argument); pass.Run(&dfg); // test content diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 2f022dd2aeea4..24cbf949d006d 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -21,13 +21,14 @@ namespace paddle { namespace inference { namespace analysis { -FluidToDataFlowGraphPass::FluidToDataFlowGraphPass() {} - -bool FluidToDataFlowGraphPass::Initialize() { return Pass::Initialize(); } - -bool FluidToDataFlowGraphPass::Initialize( - const framework::proto::ProgramDesc &desc) { - desc_ = &desc; +bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { + ANALYSIS_ARGUMENT_CHECK_FIELD(argument); + ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc); + PADDLE_ENFORCE(argument); + if (!argument->main_dfg) { + argument->main_dfg.reset(new DataFlowGraph); + } + desc_ = argument->origin_program_desc.get(); return true; } diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h index 33517e57becdf..29ee572c4dfa5 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h @@ -34,9 +34,9 @@ namespace analysis { */ class FluidToDataFlowGraphPass final : public DataFlowGraphPass { public: - FluidToDataFlowGraphPass(); - bool Initialize() override; - bool Initialize(const framework::proto::ProgramDesc &desc) override; + FluidToDataFlowGraphPass() = default; + // bool Initialize(const framework::proto::ProgramDesc &desc) override; + bool Initialize(Argument *argument) override; bool Finalize() override; void Run(DataFlowGraph *graph) override; diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc index 817d32c92cdbd..cfbbc284e491b 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc @@ -23,11 +23,11 @@ namespace analysis { TEST_F(DFG_Tester, Init) { FluidToDataFlowGraphPass pass; - pass.Initialize(); - pass.Initialize(desc); + pass.Initialize(&argument); DataFlowGraph graph; pass.Run(&graph); - ASSERT_GT(graph.nodes.size(), 0); + // Analysis is sensitive to ProgramDesc, careful to change the original model. + ASSERT_EQ(graph.nodes.size(), 37); pass.Finalize(); LOG(INFO) << '\n' << graph.DotString(); } diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 153dca576bd67..7828c89e85eab 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -60,6 +60,7 @@ struct DataTypeNamer { SET_TYPE(int); SET_TYPE(bool); SET_TYPE(float); + SET_TYPE(void *); } std::unordered_map #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/node.h" @@ -33,16 +34,21 @@ class Pass { virtual ~Pass() {} // Virtual method overridden by subclasses to do only necessary initialization // before any pass is run. - virtual bool Initialize() { return false; } + // virtual bool Initialize() { return false; } // There is some passes such as FlowToDataFlowGraphPass that needs a // ProgramDesc. Here use the native ProgramDesc ProtoBuf message, so that it // only couple with the proto file. - virtual bool Initialize(const framework::proto::ProgramDesc &desc) { - return false; - } + // virtual bool Initialize(const framework::proto::ProgramDesc &desc) { return + // false; } // There are some Passes such as DataFlowGraphToFluidPass that will output a // ProgramDesc. - virtual bool Initialize(framework::proto::ProgramDesc *desc) { return false; } + // virtual bool Initialize(framework::proto::ProgramDesc *desc) { return + // false; } + + // Mutable Pass. + virtual bool Initialize(Argument *argument) { return false; } + // Readonly Pass. + virtual bool Initialize(const Argument &argument) { return false; } // Virtual method overriden by subclasses to do any necessary clean up after // all passes have run. diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc index a737eb8657218..b07f2e50eacfd 100644 --- a/paddle/fluid/inference/analysis/pass_manager.cc +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -17,29 +17,6 @@ limitations under the License. */ namespace paddle { namespace inference { -namespace analysis { - -void PassManagerMain::RunAll(const framework::proto::ProgramDesc &desc) { - for (auto &pass : data_) { - pass->RunAll(); - } -} - -// -// CustomIterPassManager -// - -DFG_PassManager::DFG_PassManager() { - type_ = kCustomIter; - Register("fluid_to_data_flow_graph", new FluidToDataFlowGraphPass); -} - -void DFG_PassManager::RunAll() { - for (auto &pass : data_) { - pass->Run(graph_); - } -} - -} // namespace analysis +namespace analysis {} // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h index 0617504b7e796..850b02be04223 100644 --- a/paddle/fluid/inference/analysis/pass_manager.h +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -47,8 +47,8 @@ class PassManager : public OrderedRegistry { public: // Call all the passes' Initialize methods. The desc and data_flow_graph are // globally shared, so pass them as the arguemnts for all the pass managers. - virtual bool Initialize(const framework::proto::ProgramDesc &desc, - DataFlowGraph *data_flow_graph) = 0; + virtual bool Initialize(Argument* argument) { return false; } + virtual bool Initialize(const Argument& argument) { return false; } // Run all the passes. virtual void RunAll() = 0; @@ -70,12 +70,9 @@ class PassManager : public OrderedRegistry { class DfgPassManager : public PassManager { public: DfgPassManager(); - bool Initialize(const framework::proto::ProgramDesc &desc, - DataFlowGraph *data_flow_graph) override { - graph_ = data_flow_graph; - for (auto &pass : data_) { - pass->Initialize(); - pass->Initialize(desc); + bool Initialize(Argument* argument) override { + for (auto& pass : data_) { + PADDLE_ENFORCE(pass->Initialize(argument)); } return true; } @@ -83,14 +80,14 @@ class DfgPassManager : public PassManager { void RunAll() override; bool Finalize() override { - for (auto &pass : data_) { + for (auto& pass : data_) { pass->Finalize(); } return true; } private: - DataFlowGraph *graph_; + DataFlowGraph* graph_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/pass_manager_tester.cc b/paddle/fluid/inference/analysis/pass_manager_tester.cc index cdb4edfbe845c..b2a71c0d12af1 100644 --- a/paddle/fluid/inference/analysis/pass_manager_tester.cc +++ b/paddle/fluid/inference/analysis/pass_manager_tester.cc @@ -18,12 +18,6 @@ limitations under the License. */ namespace paddle { namespace inference { -namespace analysis { - - - - - -} // namespace analysis +namespace analysis {} // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc index 9ca5db7debb67..8134494f8bccb 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc @@ -72,7 +72,7 @@ TEST_F(DFG_Tester, Fuse) { SubGraphFuse fuse(&dfg, teller); fuse(); - int count1=0; + int count1 = 0; for (auto& node : dfg.nodes.nodes()) { if (node->deleted()) { LOG(INFO) << "deleted " << node->repr(); @@ -81,7 +81,7 @@ TEST_F(DFG_Tester, Fuse) { } // At least one nodes should be deleted. - ASSERT_EQ(dfg.nodes.size(), count0+1); // added a new FunctionBlock + ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock ASSERT_EQ(6UL, count1); } diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h index 328ef376c152e..79e9e2bcc9e62 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h @@ -32,7 +32,7 @@ class TensorRTSubGraphPass : public DataFlowGraphPass { TensorRTSubGraphPass(const NodeInsideSubgraphTeller& teller); - bool Initialize() override { return true; } + bool Initialize(Argument* argument) override { return true; } // This class get a sub-graph as input and determine whether to transform this // sub-graph into TensorRT. diff --git a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc index 5251c86d6dbfc..d12dcf0d0fe7f 100644 --- a/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc +++ b/paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc @@ -59,7 +59,7 @@ TEST(TensorRTSubGraph, single_pass) { // Check the TRT op's block desc for (auto node : dfg.nodes.nodes()) { - if (node->IsFunctionBlock() ) { + if (node->IsFunctionBlock()) { } } } diff --git a/paddle/fluid/inference/analysis/ut_helper.h b/paddle/fluid/inference/analysis/ut_helper.h index 722fa99a48a5f..e7450b571c36c 100644 --- a/paddle/fluid/inference/analysis/ut_helper.h +++ b/paddle/fluid/inference/analysis/ut_helper.h @@ -41,7 +41,9 @@ static DataFlowGraph ProgramDescToDFG( const framework::proto::ProgramDesc& desc) { DataFlowGraph graph; FluidToDataFlowGraphPass pass; - pass.Initialize(desc); + Argument argument; + argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); + pass.Initialize(&argument); pass.Run(&graph); pass.Finalize(); return graph; @@ -49,9 +51,12 @@ static DataFlowGraph ProgramDescToDFG( class DFG_Tester : public ::testing::Test { protected: - void SetUp() override { desc = LoadProgramDesc(FLAGS_inference_model_dir); } + void SetUp() override { + auto desc = LoadProgramDesc(FLAGS_inference_model_dir); + argument.origin_program_desc.reset(new framework::proto::ProgramDesc(desc)); + } - framework::proto::ProgramDesc desc; + Argument argument; }; } // namespace analysis From b79b018ca5fcf5e9b41c4e4697bfb10a6d3fd15a Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 17 Jun 2018 03:32:03 +0000 Subject: [PATCH 11/17] finish pass manager --- .../fluid/inference/analysis/CMakeLists.txt | 5 +- .../analysis/data_flow_graph_to_fluid_pass.h | 5 ++ .../analysis/dfg_graphviz_draw_pass.h | 6 ++ .../analysis/fluid_to_data_flow_graph_pass.cc | 3 + .../analysis/fluid_to_data_flow_graph_pass.h | 5 ++ paddle/fluid/inference/analysis/helper.h | 5 ++ paddle/fluid/inference/analysis/pass.h | 7 +- .../fluid/inference/analysis/pass_manager.cc | 24 ++++++- .../fluid/inference/analysis/pass_manager.h | 59 +++++++++++------ .../inference/analysis/pass_manager_tester.cc | 64 ++++++++++++++++++- paddle/fluid/inference/analysis/ut_helper.h | 23 +++++-- 11 files changed, 176 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index daebeeca60463..c52ac3219f6da 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -4,7 +4,7 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph data_flow_graph_to_fluid_pass.cc tensorrt_subgraph_pass.cc dfg_graphviz_draw_pass.cc - DEPS paddle_fluid) + DEPS framework_proto) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) @@ -18,7 +18,7 @@ function (inference_analysis_test TARGET) cc_test(${TARGET} SRCS "${analysis_test_SRCS}" - DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid + DEPS analysis ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) endfunction(inference_analysis_test) @@ -29,3 +29,4 @@ inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_fl inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) +inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h index 612763b7dda07..cbb05f622cc29 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -35,6 +35,11 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass { void Run(DataFlowGraph *graph) override; + std::string repr() const override { return "DFG to fluid"; } + std::string description() const override { + return "Transform a DFG to a Fluid ProgramDesc"; + } + Pass *CreatePrinterPass(std::ostream &os, const std::string &banner) const override { return nullptr; diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h index 60773571dfc40..93ebff59ae969 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h @@ -48,9 +48,15 @@ class DFG_GraphvizDrawPass : public DataFlowGraphPass { DFG_GraphvizDrawPass(const Config &config) : config_(config) {} + bool Initialize(Argument *argument) override { return true; } void Run(DataFlowGraph *graph) override; bool Finalize() override { return Pass::Finalize(); } + std::string repr() const override { return "DFG graphviz drawer"; } + std::string description() const override { + return "Debug a DFG by draw with graphviz"; + } + private: // Path of the dot file to output. std::string GenDotPath() const { diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 24cbf949d006d..5f62eef52876a 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -26,6 +26,7 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc); PADDLE_ENFORCE(argument); if (!argument->main_dfg) { + LOG(INFO) << "Init DFG"; argument->main_dfg.reset(new DataFlowGraph); } desc_ = argument->origin_program_desc.get(); @@ -35,6 +36,8 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) { bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); } void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { + PADDLE_ENFORCE(graph); + PADDLE_ENFORCE(desc_); // insert vars std::unordered_map var2id; auto &main_block = desc_->blocks(framework::kRootBlockIndex); diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h index 29ee572c4dfa5..56bb8c8b3cf41 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h @@ -41,6 +41,11 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass { void Run(DataFlowGraph *graph) override; + std::string repr() const override { return "fluid-to-data-flow-graph"; } + std::string description() const override { + return "transform a fluid ProgramDesc to a data flow graph."; + } + Pass *CreatePrinterPass(std::ostream &os, const std::string &banner) const override; diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 7828c89e85eab..cd7cdf416aef3 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -108,6 +108,11 @@ class OrderedRegistry { std::vector> data_; }; +template +std::unique_ptr make_unique(Args &&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass.h b/paddle/fluid/inference/analysis/pass.h index 1978aef4af287..65632b749177a 100644 --- a/paddle/fluid/inference/analysis/pass.h +++ b/paddle/fluid/inference/analysis/pass.h @@ -31,7 +31,7 @@ namespace analysis { class Pass { public: Pass() = default; - virtual ~Pass() {} + virtual ~Pass() = default; // Virtual method overridden by subclasses to do only necessary initialization // before any pass is run. // virtual bool Initialize() { return false; } @@ -68,6 +68,11 @@ class Pass { virtual void Run(FunctionBlock *x) { LOG(FATAL) << "not valid"; } // Run on a single DataFlowGraph. virtual void Run(DataFlowGraph *x) { LOG(FATAL) << "not valid"; } + + // Human-readable short representation. + virtual std::string repr() const = 0; + // Human-readable long description. + virtual std::string description() const = 0; }; // NodePass process on any Node types. diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc index b07f2e50eacfd..b17c0e0d724eb 100644 --- a/paddle/fluid/inference/analysis/pass_manager.cc +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -17,6 +17,28 @@ limitations under the License. */ namespace paddle { namespace inference { -namespace analysis {} // namespace analysis +namespace analysis { + +void DfgPassManager::RunAll() { + PADDLE_ENFORCE(argument_); + for (auto& pass : data_) { + VLOG(4) << "Running pass [" << pass->repr() << "]"; + pass->Run(argument_->main_dfg.get()); + } +} + +void NodePassManager::RunAll() { + PADDLE_ENFORCE(argument_); + PADDLE_ENFORCE(argument_->main_dfg.get()); + auto trait = + GraphTraits(argument_->main_dfg.get()).nodes_in_DFS(); + for (auto& node : trait) { + for (auto& pass : data_) { + pass->Run(&node); + } + } +} + +} // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/pass_manager.h b/paddle/fluid/inference/analysis/pass_manager.h index 850b02be04223..7841c4b9d0800 100644 --- a/paddle/fluid/inference/analysis/pass_manager.h +++ b/paddle/fluid/inference/analysis/pass_manager.h @@ -45,16 +45,36 @@ namespace analysis { */ class PassManager : public OrderedRegistry { public: + PassManager() = default; // Call all the passes' Initialize methods. The desc and data_flow_graph are // globally shared, so pass them as the arguemnts for all the pass managers. - virtual bool Initialize(Argument* argument) { return false; } virtual bool Initialize(const Argument& argument) { return false; } - // Run all the passes. - virtual void RunAll() = 0; + virtual bool Initialize(Argument* argument) { + argument_ = argument; + for (auto& pass : data_) { + LOG(INFO) << "Initializing pass " << pass->repr(); + if (!pass->Initialize(argument)) { + LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]"; + return false; + } + } + return true; + } // Call all the passes' Finalize methods. - virtual bool Finalize() = 0; + virtual bool Finalize() { + for (auto& pass : data_) { + if (!pass->Finalize()) { + LOG(ERROR) << "Failed to finalize pass [" << pass->repr() << "]"; + return false; + } + } + return true; + } + + // Run all the passes. + virtual void RunAll() = 0; // Short identifier. virtual std::string repr() const = 0; @@ -62,6 +82,9 @@ class PassManager : public OrderedRegistry { virtual std::string description() const = 0; virtual ~PassManager() = default; + + protected: + Argument* argument_{nullptr}; }; /* @@ -69,25 +92,23 @@ class PassManager : public OrderedRegistry { */ class DfgPassManager : public PassManager { public: - DfgPassManager(); - bool Initialize(Argument* argument) override { - for (auto& pass : data_) { - PADDLE_ENFORCE(pass->Initialize(argument)); - } - return true; - } + DfgPassManager() = default; void RunAll() override; - bool Finalize() override { - for (auto& pass : data_) { - pass->Finalize(); - } - return true; - } + virtual ~DfgPassManager() = default; +}; + +/* + * A pass manager that process a Node each time. + */ +class NodePassManager : public PassManager { + public: + NodePassManager() = default; + + void RunAll() override; - private: - DataFlowGraph* graph_; + virtual ~NodePassManager() = default; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/pass_manager_tester.cc b/paddle/fluid/inference/analysis/pass_manager_tester.cc index b2a71c0d12af1..7af6a19951463 100644 --- a/paddle/fluid/inference/analysis/pass_manager_tester.cc +++ b/paddle/fluid/inference/analysis/pass_manager_tester.cc @@ -13,11 +13,73 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/analysis/pass_manager.h" +#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" +#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" +#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" +#include "paddle/fluid/inference/analysis/ut_helper.h" #include namespace paddle { namespace inference { -namespace analysis {} // namespace analysis +namespace analysis { + +class TestDfgPassManager final : public DfgPassManager { + public: + TestDfgPassManager() = default; + virtual ~TestDfgPassManager() = default; + // Short identifier. + std::string repr() const override { return "test-pass-manager"; } + // Long description. + std::string description() const override { return "test doc"; } +}; + +class TestNodePassManager final : public NodePassManager { + public: + virtual ~TestNodePassManager() = default; + + std::string repr() const override { return "test-node-pass-manager"; } + std::string description() const override { return "test doc"; } +}; + +class TestNodePass final : public NodePass { + public: + virtual ~TestNodePass() = default; + + bool Initialize(Argument* argument) override { return true; } + + void Run(Node* node) override { + LOG(INFO) << "- Processing node " << node->repr(); + } + + std::string repr() const override { return "test-node"; } + std::string description() const override { return "some doc"; } +}; + +TEST_F(DFG_Tester, DFG_pass_manager) { + TestDfgPassManager manager; + DFG_GraphvizDrawPass::Config config("./", "dfg.dot"); + + manager.Register("fluid-to-flow-graph", new FluidToDataFlowGraphPass); + manager.Register("graphviz", new DFG_GraphvizDrawPass(config)); + manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass); + + ASSERT_TRUE(manager.Initialize(&argument)); + manager.RunAll(); +} + +TEST_F(DFG_Tester, Node_pass_manager) { + // Pre-process: initialize the DFG with the ProgramDesc first. + FluidToDataFlowGraphPass pass0; + pass0.Initialize(&argument); + pass0.Run(argument.main_dfg.get()); + + TestNodePassManager manager; + manager.Register("test-node-pass", new TestNodePass); + ASSERT_TRUE(manager.Initialize(&argument)); + manager.RunAll(); +} + +} // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/ut_helper.h b/paddle/fluid/inference/analysis/ut_helper.h index e7450b571c36c..ce1191a567a41 100644 --- a/paddle/fluid/inference/analysis/ut_helper.h +++ b/paddle/fluid/inference/analysis/ut_helper.h @@ -15,26 +15,37 @@ limitations under the License. */ #pragma once #include #include +#include #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" -#include "paddle/fluid/inference/io.h" namespace paddle { namespace inference { + +// Read ProgramDesc from a __model__ file, defined in io.cc +extern void ReadBinaryFile(const std::string& filename, std::string* contents); + namespace analysis { DEFINE_string(inference_model_dir, "", "inference test model dir"); static framework::proto::ProgramDesc LoadProgramDesc( const std::string& model_dir = FLAGS_inference_model_dir) { - paddle::platform::CPUPlace place; - paddle::framework::Executor executor(place); - paddle::framework::Scope scope; - auto program = Load(&executor, &scope, model_dir); - return *program->Proto(); + std::string msg; + std::string net_file = FLAGS_inference_model_dir + "/__model__"; + std::ifstream fin(net_file, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", net_file); + fin.seekg(0, std::ios::end); + msg.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&(msg.at(0)), msg.size()); + fin.close(); + framework::proto::ProgramDesc program_desc; + program_desc.ParseFromString(msg); + return program_desc; } static DataFlowGraph ProgramDescToDFG( From 98be331a3bd80e036760a3452202b6303a8e4003 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 17 Jun 2018 04:03:00 +0000 Subject: [PATCH 12/17] disable tensorrt pass --- paddle/fluid/inference/analysis/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index c52ac3219f6da..2bb2c8135d8c3 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -28,5 +28,5 @@ inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_ inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) -inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) +#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) From 70f09aa5495107a7d45c142ae4c19bcf8e40a67e Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 17 Jun 2018 05:49:34 +0000 Subject: [PATCH 13/17] fix tests --- .../fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc | 3 ++- paddle/fluid/operators/tensorrt_engine_op_test.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc index 0ce35cb974676..f4b5c5fd2201c 100644 --- a/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc +++ b/paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc @@ -39,7 +39,8 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) { while (std::getline(file, line)) { no++; } - ASSERT_EQ(no, 82); + // DFG is sensitive to ProgramDesc, be careful to change the existing models. + ASSERT_EQ(no, 112); } } // namespace analysis diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 85330958cdba9..3a2fef48052ae 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -240,7 +240,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { } // Test with a larger FC layer. -TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); } +TEST(TensorRTEngineOp, fc) { Execute(40, 28, 28); } } // namespace operators } // namespace paddle From 12a3b27fa885ca320880093a540d3e3e8be8f7f5 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Mon, 18 Jun 2018 03:52:18 +0000 Subject: [PATCH 14/17] fix tests --- paddle/fluid/inference/analysis/argument.h | 10 +++++----- .../inference/analysis/fluid_to_data_flow_graph_pass.h | 2 +- paddle/fluid/inference/analysis/helper.h | 5 ----- paddle/fluid/inference/analysis/node.h | 3 --- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 4b46eef76921a..0a40f68498d4e 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -41,11 +41,11 @@ struct Argument { std::unique_ptr origin_program_desc; }; -#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ - if (!(field__)) { \ - LOG(ERROR) << "field " << #field__ << " should be set."; \ - return false; \ - } +#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ + if (__builtin_expect(static_cast(condition), 0))) { \ + LOG(ERROR) << "field " << #field__ << " should be set."; \ + return false; \ + } } // namespace analysis } // namespace inference diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h index 56bb8c8b3cf41..176faf0220cc9 100644 --- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h @@ -35,7 +35,7 @@ namespace analysis { class FluidToDataFlowGraphPass final : public DataFlowGraphPass { public: FluidToDataFlowGraphPass() = default; - // bool Initialize(const framework::proto::ProgramDesc &desc) override; + bool Initialize(Argument *argument) override; bool Finalize() override; diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 9037a0e6cdaec..f0039e113159f 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -110,11 +110,6 @@ class OrderedRegistry { std::vector> data_; }; -template -std::unique_ptr make_unique(Args &&... args) { - return std::unique_ptr(new T(std::forward(args)...)); -} - template T &GetFromScope(const framework::Scope &scope, const std::string &name) { framework::Variable *var = scope.FindVar(name); diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h index fa0afe4a6bbfe..8c2e6d88b9605 100644 --- a/paddle/fluid/inference/analysis/node.h +++ b/paddle/fluid/inference/analysis/node.h @@ -150,9 +150,6 @@ class Node { Type type_{Type::kNone}; // Mark this node is deleted by some pass. bool deleted_{false}; - - // void *extra_info_; - mutable std::unordered_map attrs_; }; From d64aa80f2dc717fdb86e39ebbc055af754bc6a22 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Mon, 18 Jun 2018 03:59:19 +0000 Subject: [PATCH 15/17] fix tests --- paddle/fluid/inference/analysis/argument.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 0a40f68498d4e..a6e53e1107039 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -41,11 +41,12 @@ struct Argument { std::unique_ptr origin_program_desc; }; -#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ - if (__builtin_expect(static_cast(condition), 0))) { \ - LOG(ERROR) << "field " << #field__ << " should be set."; \ - return false; \ - } +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) +#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ + if (UNLIKELY(field__)) \ + LOG(ERROR) << "field " << #field__ << " should be set."; \ + return false; \ + } } // namespace analysis } // namespace inference From f29b55e09c2637731c8f0fd8fed3d868ccd2be34 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Mon, 18 Jun 2018 04:54:20 +0000 Subject: [PATCH 16/17] fix tests --- paddle/fluid/inference/analysis/argument.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index a6e53e1107039..93181c481b1ac 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -43,9 +43,9 @@ struct Argument { #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ - if (UNLIKELY(field__)) \ + if (UNLIKELY(field__)) { \ LOG(ERROR) << "field " << #field__ << " should be set."; \ - return false; \ + return false; \ } } // namespace analysis From e99d4516f7e0b53aa30eac0c494023e89cb38f15 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Mon, 18 Jun 2018 07:32:17 +0000 Subject: [PATCH 17/17] fix tests --- paddle/fluid/inference/analysis/argument.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 93181c481b1ac..7d7131ed7a188 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -43,7 +43,7 @@ struct Argument { #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ - if (UNLIKELY(field__)) { \ + if (!UNLIKELY(field__)) { \ LOG(ERROR) << "field " << #field__ << " should be set."; \ return false; \ }