Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pass manager #11440

Merged
merged 20 commits into from
Jun 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions paddle/fluid/inference/analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
cc_library(analysis SRCS dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc fluid_to_data_flow_graph_pass.cc
DEPS paddle_fluid)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc
tensorrt_subgraph_pass.cc
dfg_graphviz_draw_pass.cc
DEPS framework_proto)
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)

set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)

cc_test(test_data_flow_graph SRCS data_flow_graph_tester.cc DEPS analysis ${FLUID_CORE_MODULES} paddle_fluid
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
set_tests_properties(test_data_flow_graph PROPERTIES DEPENDS test_word2vec)
function (inference_analysis_test TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

cc_test(test_subgraph_splitter
SRCS subgraph_splitter_tester.cc
DEPS analysis paddle_fluid tensor
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
set_tests_properties(test_subgraph_splitter PROPERTIES DEPENDS test_word2vec)
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endfunction(inference_analysis_test)

cc_test(test_dfg_graphviz_draw_pass
SRCS dfg_graphviz_draw_pass_tester.cc
DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
set_tests_properties(test_dfg_graphviz_draw_pass PROPERTIES DEPENDS test_word2vec)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
15 changes: 15 additions & 0 deletions paddle/fluid/inference/analysis/argument.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/argument.h"
53 changes: 53 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
* This file defines the class Argument, which is the input and output of the
* analysis module. All the fields that needed either by Passes or PassManagers
* are contained in Argument.
*
* TODO(Superjomn) Find some way better to contain the fields when it grow too
* big.
*/

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"

namespace paddle {
namespace inference {
namespace analysis {

/*
* The argument definition of both Pass and PassManagers.
*
* All the fields should be registered here for clearness.
*/
struct Argument {
// The graph that process by the Passes or PassManagers.
std::unique_ptr<DataFlowGraph> main_dfg;

// The original program desc.
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
};

#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \
if (!UNLIKELY(field__)) { \
LOG(ERROR) << "field " << #field__ << " should be set."; \
return false; \
}

} // namespace analysis
} // namespace inference
} // namespace paddle
15 changes: 2 additions & 13 deletions paddle/fluid/inference/analysis/data_flow_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/node.h"

namespace paddle {
namespace inference {
Expand Down Expand Up @@ -57,19 +58,7 @@ std::string DataFlowGraph::DotString() const {
// Add nodes
for (size_t i = 0; i < nodes.size(); i++) {
const Node &node = nodes.Get(i);
switch (node.type()) {
case Node::Type::kValue:
dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunction:
dot.AddNode(node.repr(), node.dot_attrs());
break;
case Node::Type::kFunctionBlock:
dot.AddNode(node.repr(), node.dot_attrs());
break;
default:
PADDLE_THROW("unsupported Node type %d", static_cast<int>(node.type()));
}
dot.AddNode(node.repr(), node.dot_attrs());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DotString is not a good name.
Maybe DotGraphgString or String is better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a Graph word in DataFlowGraph::DotString, I think DotString is more clear.

}

// Add edges
Expand Down
77 changes: 77 additions & 0 deletions paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/framework/proto_desc.h"

namespace paddle {
namespace inference {
namespace analysis {

bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
desc_ = argument->origin_program_desc.get();
// Here some logic from program_desc.cc and will not add new interfaces into
// framework::ProgramDesc class, use some UT to assure the correctness.
auto* block = desc_->mutable_blocks()->Add();
block->set_idx(framework::kRootBlockIndex);
block->set_parent_idx(framework::kNoneBlockIndex);
return true;
}

bool DataFlowGraphToFluidPass::Finalize() { return true; }

void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
auto traits = GraphTraits<DataFlowGraph>(graph);
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
if (it->deleted()) continue;
switch (it->type()) {
case Node::Type::kFunction:
LOG(INFO) << "add function " << it->name();
AddFluidOp(&(*it));
break;
case Node::Type::kFunctionBlock:
AddEngineOp(&(*it));
break;
default:
continue;
}
}
}

void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
LOG(INFO) << "processing func " << node->name();
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
// currently only the main block is analyzed.
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
LOG(INFO) << "to copy the op";
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
// subgraph pass.
// NOTE It might be changed by other passes in the long run.
}

void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
// auto* op = main_block->add_ops();
// TODO(Superjomn) Here need to expose some arguments for default setting.
}

} // namespace analysis
} // namespace inference
} // namespace paddle
59 changes: 59 additions & 0 deletions paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

/*
* This file implements the transformation from fluid ProgramDesc to data flow
* graph.
*/

#pragma once

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"

namespace paddle {
namespace inference {
namespace analysis {
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
public:
DataFlowGraphToFluidPass() = default;

bool Initialize(Argument *argument) override;
bool Finalize() override;

void Run(DataFlowGraph *graph) override;

std::string repr() const override { return "DFG to fluid"; }
std::string description() const override {
return "Transform a DFG to a Fluid ProgramDesc";
}

Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override {
return nullptr;
}

protected:
// Add a Fluid Op into the ProgramDesc.
void AddFluidOp(Node *node);
// Add a EngineOp into the ProgramDesc.
void AddEngineOp(Node *node);

private:
framework::proto::ProgramDesc *desc_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ namespace inference {
namespace analysis {

TEST_F(DFG_Tester, Test) {
framework::proto::ProgramDesc new_desc;
DataFlowGraph graph;

FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1;
pass0.Initialize(desc);
pass1.Initialize(&new_desc);
ASSERT_TRUE(pass0.Initialize(&argument));
ASSERT_TRUE(pass1.Initialize(&argument));

pass0.Run(&graph);
pass1.Run(&graph);
Expand Down
54 changes: 54 additions & 0 deletions paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"

namespace paddle {
namespace inference {
namespace analysis {

void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph);
std::ofstream file(GenDotPath());
file.write(content.c_str(), content.size());
file.close();
LOG(INFO) << "draw dot to " << GenDotPath();
}

std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For simplicity, each pass takes a DataFlowGraph* as input, I will enhance this and protect the read operation in another PR latter.

Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
dot.AddNode(node.repr(), node.dot_attrs());
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
for (auto &in : node.inlinks) {
dot.AddEdge(in->repr(), node.repr(), {});
}
}
}
return dot.Build();
}

} // namespace analysis
} // namespace inference
} // namespace paddle
Loading