Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. #36842

Merged
merged 18 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)

cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
Expand Down
43 changes: 40 additions & 3 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -25,6 +26,8 @@ limitations under the License. */

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
Expand All @@ -34,6 +37,9 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops);

namespace paddle {
namespace framework {
namespace paddle2cinn {
Expand All @@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>;

namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";";

std::unordered_set<std::string> StringSplit(const std::string& str,
const std::string& delim) {
std::regex reg(delim);
std::unordered_set<std::string> elems{
std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
std::sregex_token_iterator()};
elems.erase("");
return elems;
}

int ExtractOpRole(const GraphNodeSet& cluster) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
Expand Down Expand Up @@ -340,9 +360,26 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
// to check whether the op node supported by CINN.
void SearchAllSubgraphs(Graph* graph) {
auto teller = [](const Node* node) {
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
nullptr;
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以放在外面吧?就不用每次都计算了

Suggested change
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto teller = [&allow_ops](const Node* node) { ... };

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if (allow_ops.size()) {
return registered && allow_ops.count(node->Name());
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name());
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered;
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
std::vector<GraphNodeVec> clusters =
framework::ir::SubgraphDetector(graph, teller)();

Expand Down Expand Up @@ -375,7 +412,7 @@ void SearchAllSubgraphs(Graph* graph) {
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
VLOG(4) << "Compilation Key: " << compilation_key;
VLOG(4) << "Compilation Key:\n" << ReadableProtoStr(compilation_key);

// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
Expand Down
52 changes: 40 additions & 12 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"

Expand All @@ -59,40 +61,60 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
ProgramDesc program;
GraphToProgram(*graph, &program);
program.Proto()->SerializeToString(&graph_key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

序列化的字符串直接作为key比较冗余,查找效率低、还占空间。是否以其hash值作为key,CinnCompiler额外提供接口可以由key获取其子图的序列化字符串?

Copy link
Contributor Author

@wzzju wzzju Nov 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;

std::string作为std::unordered_map的key,查找时就是使用string的hash code查找的,你说不存储这个字符串,是想通过hash code回转得到愿字符串吗?这个恐怕是做不到的face_69@2x

Copy link
Contributor

@CtfGo CtfGo Nov 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不是,查找时是string->hash吧,我意思是key存hash code,CinnCompiler提供接口由key获取graph->debug string替代ReadableProtoStr,不过这不重要。

if (!graphs_.count(graph_key)) {
graphs_[graph_key] = std::move(graph);
} else {
LOG(WARNING)
<< "The graph being added is already in CinnCompiler. Its key is:\n"
<< graph_key;
VLOG(4) << "Add a graph into CinnCompiler, which is:\n"
<< ReadableProtoStr(graph_key);
{
AutoWRLock guard{&rwlock_};
if (!graphs_.count(graph_key)) {
graphs_[graph_key] = std::move(graph);
} else {
LOG(WARNING)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

会出现子图已经被注册的情况吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种情况不确定,如果只是运用build_cinn_pass一次,应该不会出现,这里已改为使用PADDLE_ENFORCE

<< "The graph being added is already in CinnCompiler. Its key is:\n"
<< ReadableProtoStr(graph_key);
}
}
return graph_key;
}

const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
AutoRDLock guard{&rwlock_};
PADDLE_ENFORCE_NE(
graphs_.count(graph_key), 0,
platform::errors::InvalidArgument("Can not find the target graph: %s",
graph_key.c_str()));
return *graphs_.at(graph_key);
platform::errors::InvalidArgument("Can not find the target graph:\n%s",
ReadableProtoStr(graph_key).c_str()));
const auto& graph = *graphs_.at(graph_key);
return graph;
}

const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
if (!cache_.count(cur_key)) {
bool exist = false;
{
AutoRDLock r_guard{&rwlock_};
exist = cache_.count(cur_key) != 0;
}
if (!exist) {
real_compiled_num_++;
cache_[cur_key] = CompileGraph(graph, input_tensors, target);
auto compiled_res = CompileGraph(graph, input_tensors, target);
AutoWRLock w_guard{&rwlock_};
if (!cache_.count(cur_key)) {
cache_[cur_key] = std::move(compiled_res);
}
}
return *cache_[cur_key];
AutoRDLock guard{&rwlock_};
const auto& cached_boj = *cache_[cur_key];
return cached_boj;
}

const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
VLOG(4) << "The graph to be compiled is:\n"
<< ReadableProtoStr(compilation_key);
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
}
Expand Down Expand Up @@ -125,6 +147,12 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
return compiled_obj;
}

std::string ReadableProtoStr(const std::string& bytes) {
proto::ProgramDesc program_desc;
program_desc.ParseFromString(bytes);
return program_desc.DebugString();
}

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
13 changes: 13 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"

Expand Down Expand Up @@ -64,6 +65,15 @@ class CinnCompiler {

const ir::Graph& FindGraph(const std::string& key) const;

void Clear() {
{
AutoWRLock guard{&rwlock_};
graphs_.clear();
cache_.clear();
}
real_compiled_num_ = 0;
}

std::int64_t real_compiled_num() const { return real_compiled_num_; }

~CinnCompiler() = default;
Expand All @@ -80,10 +90,13 @@ class CinnCompiler {
CinnCacheKey::Hash>
cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable RWLock rwlock_;

DISABLE_COPY_AND_ASSIGN(CinnCompiler);
};

extern std::string ReadableProtoStr(const std::string& bytes);

} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
Loading