-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/pass manager #11440
Feature/pass manager #11440
Changes from all commits
2e69429
1292dbc
9776b3a
4ae9f53
4ff4ca0
c687067
c386909
30a3ebc
9daeed2
743f1f8
8db0964
1ab4dbb
b79b018
d598826
98be331
70f09aa
12a3b27
d64aa80
f29b55e
e99d451
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,32 @@ | ||
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init) | ||
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc | ||
DEPS paddle_fluid) | ||
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc | ||
fluid_to_data_flow_graph_pass.cc | ||
data_flow_graph_to_fluid_pass.cc | ||
tensorrt_subgraph_pass.cc | ||
dfg_graphviz_draw_pass.cc | ||
DEPS framework_proto) | ||
cc_test(test_node SRCS node_tester.cc DEPS analysis) | ||
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) | ||
|
||
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) | ||
|
||
cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid | ||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) | ||
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec) | ||
function (inference_analysis_test TARGET) | ||
set(options "") | ||
set(oneValueArgs "") | ||
set(multiValueArgs SRCS) | ||
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) | ||
|
||
cc_test(test_subgraph_splitter | ||
SRCS subgraph_splitter_tester.cc | ||
DEPS analysis paddle_fluid tensor | ||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) | ||
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec) | ||
cc_test(${TARGET} | ||
SRCS "${analysis_test_SRCS}" | ||
DEPS analysis | ||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5) | ||
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) | ||
endfunction(inference_analysis_test) | ||
|
||
cc_test(test_dfg_graphviz_draw_pass | ||
SRCS dfg_graphviz_draw_pass_tester.cc | ||
DEPS analysis | ||
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model) | ||
set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec) | ||
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) | ||
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) | ||
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) | ||
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) | ||
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) | ||
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) | ||
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/inference/analysis/argument.h" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
/* | ||
* This file defines the class Argument, which is the input and output of the | ||
* analysis module. All the fields that needed either by Passes or PassManagers | ||
* are contained in Argument. | ||
* | ||
* TODO(Superjomn) Find some way better to contain the fields when it grow too | ||
* big. | ||
*/ | ||
|
||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/inference/analysis/data_flow_graph.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace analysis { | ||
|
||
/* | ||
* The argument definition of both Pass and PassManagers. | ||
* | ||
* All the fields should be registered here for clearness. | ||
*/ | ||
struct Argument { | ||
// The graph that process by the Passes or PassManagers. | ||
std::unique_ptr<DataFlowGraph> main_dfg; | ||
|
||
// The original program desc. | ||
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc; | ||
}; | ||
|
||
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) | ||
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ | ||
if (!UNLIKELY(field__)) { \ | ||
LOG(ERROR) << "field " << #field__ << " should be set."; \ | ||
return false; \ | ||
} | ||
|
||
} // namespace analysis | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" | ||
#include "paddle/fluid/framework/proto_desc.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace analysis { | ||
|
||
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) { | ||
ANALYSIS_ARGUMENT_CHECK_FIELD(argument) | ||
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) | ||
desc_ = argument->origin_program_desc.get(); | ||
// Here some logic from program_desc.cc and will not add new interfaces into | ||
// framework::ProgramDesc class, use some UT to assure the correctness. | ||
auto* block = desc_->mutable_blocks()->Add(); | ||
block->set_idx(framework::kRootBlockIndex); | ||
block->set_parent_idx(framework::kNoneBlockIndex); | ||
return true; | ||
} | ||
|
||
bool DataFlowGraphToFluidPass::Finalize() { return true; } | ||
|
||
void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) { | ||
auto traits = GraphTraits<DataFlowGraph>(graph); | ||
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) { | ||
if (it->deleted()) continue; | ||
switch (it->type()) { | ||
case Node::Type::kFunction: | ||
LOG(INFO) << "add function " << it->name(); | ||
AddFluidOp(&(*it)); | ||
break; | ||
case Node::Type::kFunctionBlock: | ||
AddEngineOp(&(*it)); | ||
break; | ||
default: | ||
continue; | ||
} | ||
} | ||
} | ||
|
||
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) { | ||
LOG(INFO) << "processing func " << node->name(); | ||
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc()); | ||
// currently only the main block is analyzed. | ||
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); | ||
auto* op = main_block->add_ops(); | ||
LOG(INFO) << "to copy the op"; | ||
*op = *ori_op; // copy the attributes, by default, these will not be changed | ||
// by analysis phrase. | ||
// The inputs and outputs of the existing ops are not changed by tensorrt | ||
// subgraph pass. | ||
// NOTE It might be changed by other passes in the long run. | ||
} | ||
|
||
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) { | ||
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info()); | ||
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex); | ||
// auto* op = main_block->add_ops(); | ||
// TODO(Superjomn) Here need to expose some arguments for default setting. | ||
} | ||
|
||
} // namespace analysis | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
/* | ||
* This file implements the transformation from fluid ProgramDesc to data flow | ||
* graph. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/inference/analysis/data_flow_graph.h" | ||
#include "paddle/fluid/inference/analysis/pass.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace analysis { | ||
class DataFlowGraphToFluidPass final : public DataFlowGraphPass { | ||
public: | ||
DataFlowGraphToFluidPass() = default; | ||
|
||
bool Initialize(Argument *argument) override; | ||
bool Finalize() override; | ||
|
||
void Run(DataFlowGraph *graph) override; | ||
|
||
std::string repr() const override { return "DFG to fluid"; } | ||
std::string description() const override { | ||
return "Transform a DFG to a Fluid ProgramDesc"; | ||
} | ||
|
||
Pass *CreatePrinterPass(std::ostream &os, | ||
const std::string &banner) const override { | ||
return nullptr; | ||
} | ||
|
||
protected: | ||
// Add a Fluid Op into the ProgramDesc. | ||
void AddFluidOp(Node *node); | ||
// Add a EngineOp into the ProgramDesc. | ||
void AddEngineOp(Node *node); | ||
|
||
private: | ||
framework::proto::ProgramDesc *desc_; | ||
}; | ||
} // namespace analysis | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace analysis { | ||
|
||
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { | ||
auto content = Draw(graph); | ||
std::ofstream file(GenDotPath()); | ||
file.write(content.c_str(), content.size()); | ||
file.close(); | ||
LOG(INFO) << "draw dot to " << GenDotPath(); | ||
} | ||
|
||
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Const DataFlowGraph& https://google.github.io/styleguide/cppguide.html#Reference_Arguments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For simplicity, each pass takes a |
||
Dot dot; | ||
// Add nodes | ||
for (size_t i = 0; i < graph->nodes.size(); i++) { | ||
const Node &node = graph->nodes.Get(i); | ||
if (config_.display_deleted_node || !node.deleted()) { | ||
dot.AddNode(node.repr(), node.dot_attrs()); | ||
} | ||
} | ||
// Add edges | ||
for (size_t i = 0; i < graph->nodes.size(); i++) { | ||
const Node &node = graph->nodes.Get(i); | ||
if (!config_.display_deleted_node && node.deleted()) continue; | ||
for (auto &in : node.inlinks) { | ||
if (!config_.display_deleted_node && in->deleted()) continue; | ||
for (auto &in : node.inlinks) { | ||
dot.AddEdge(in->repr(), node.repr(), {}); | ||
} | ||
} | ||
} | ||
return dot.Build(); | ||
} | ||
|
||
} // namespace analysis | ||
} // namespace inference | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DotString is not a good name.
Maybe DotGraphgString or String is better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is a
Graph
word inDataFlowGraph::DotString
, I thinkDotString
is more clear.