-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. #36842
Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. #36842
Changes from 9 commits
af86c51
f4844b7
1d13d38
0dc23a7
10b3b5d
91065d4
8cb6b19
6560f06
a2d9398
607a3cf
442422e
0f86f5f
ee88c06
c517414
2d87862
88bd65d
53e8a37
9433f0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ limitations under the License. */ | |
#include <algorithm> | ||
#include <iterator> | ||
#include <memory> | ||
#include <regex> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
@@ -25,6 +26,8 @@ limitations under the License. */ | |
|
||
#include "cinn/frontend/op_mapper_registry.h" | ||
#include "cinn/frontend/op_mappers/use_op_mappers.h" | ||
#include "gflags/gflags.h" | ||
#include "glog/logging.h" | ||
#include "paddle/fluid/framework/ir/graph.h" | ||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" | ||
#include "paddle/fluid/framework/ir/node.h" | ||
|
@@ -34,6 +37,9 @@ limitations under the License. */ | |
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/errors.h" | ||
|
||
DECLARE_string(allow_cinn_ops); | ||
DECLARE_string(deny_cinn_ops); | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace paddle2cinn { | ||
|
@@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set<Node*>; | |
using GraphNodeMap = std::unordered_map<Node*, Node*>; | ||
|
||
namespace { | ||
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops | ||
// & FLAGS_deny_cinn_ops. | ||
constexpr char kDelim[] = ";"; | ||
|
||
// Splits `str` on every occurrence of the separator `delim` and returns the
// set of distinct, non-empty pieces.
//
// `delim` is matched as a literal substring, not as a regular expression, so
// separators such as "." or "|" behave as written and no regex has to be
// compiled on each call. Empty pieces (leading, trailing or repeated
// separators) are dropped. An empty `delim` performs no splitting: the whole
// of `str` is returned as the single element (or nothing if `str` is empty).
std::unordered_set<std::string> StringSplit(const std::string& str,
                                            const std::string& delim) {
  std::unordered_set<std::string> elems;
  if (delim.empty()) {
    if (!str.empty()) {
      elems.insert(str);
    }
    return elems;
  }
  std::string::size_type begin = 0;
  while (begin <= str.size()) {
    // find() yields npos when no further separator exists; clamp it to the
    // end of the string so the final piece is handled by the same code path.
    const std::string::size_type end =
        std::min(str.find(delim, begin), str.size());
    if (end > begin) {
      elems.emplace(str, begin, end - begin);
    }
    begin = end + delim.size();
  }
  return elems;
}
|
||
int ExtractOpRole(const GraphNodeSet& cluster) { | ||
std::unordered_set<int> op_roles; | ||
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName(); | ||
|
@@ -340,9 +360,26 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster, | |
// to check whether the op node supported by CINN. | ||
void SearchAllSubgraphs(Graph* graph) { | ||
auto teller = [](const Node* node) { | ||
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) != | ||
nullptr; | ||
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( | ||
node->Name()) != nullptr; | ||
// if the op type is registered in CINN and allow_ops is not empty, return | ||
// true only when it is in allow_ops | ||
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); | ||
if (allow_ops.size()) { | ||
return registered && allow_ops.count(node->Name()); | ||
} | ||
// if the op type is registered in CINN and deny_ops is not empty, return | ||
// true only when it is not in deny_ops | ||
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (deny_ops.size()) { | ||
return registered && !deny_ops.count(node->Name()); | ||
} | ||
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, | ||
// return true only when it is registered in CINN | ||
return registered; | ||
}; | ||
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; | ||
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; | ||
std::vector<GraphNodeVec> clusters = | ||
framework::ir::SubgraphDetector(graph, teller)(); | ||
|
||
|
@@ -375,7 +412,7 @@ void SearchAllSubgraphs(Graph* graph) { | |
// save it in CinnCompiler | ||
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( | ||
cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); | ||
VLOG(4) << "Compilation Key: " << compilation_key; | ||
VLOG(4) << "Compilation Key:\n" << ReadableProtoStr(compilation_key); | ||
|
||
// Replace the found cluster to a new cinn op node | ||
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,11 +29,13 @@ | |
#include "cinn/hlir/framework/graph_compiler.h" | ||
#include "cinn/hlir/framework/pass.h" | ||
#include "cinn/hlir/pass/use_pass.h" | ||
#include "paddle/fluid/framework/framework.pb.h" | ||
#include "paddle/fluid/framework/ir/graph.h" | ||
#include "paddle/fluid/framework/ir/graph_helper.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/framework/rw_lock.h" | ||
#include "paddle/fluid/framework/tensor.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
|
@@ -59,40 +61,60 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { | |
ProgramDesc program; | ||
GraphToProgram(*graph, &program); | ||
program.Proto()->SerializeToString(&graph_key); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 序列化的字符串直接作为key比较冗余,查找效率低、还占空间。是否以其hash值作为key,CinnCompiler额外提供接口可以由key获取其子图的序列化字符串? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 不是,查找时是string->hash吧,我意思是key存hash code,CinnCompiler提供接口由key获取graph->debug string替代ReadableProtoStr,不过这不重要。 |
||
if (!graphs_.count(graph_key)) { | ||
graphs_[graph_key] = std::move(graph); | ||
} else { | ||
LOG(WARNING) | ||
<< "The graph being added is already in CinnCompiler. Its key is:\n" | ||
<< graph_key; | ||
VLOG(4) << "Add a graph into CinnCompiler, which is:\n" | ||
<< ReadableProtoStr(graph_key); | ||
{ | ||
AutoWRLock guard{&rwlock_}; | ||
if (!graphs_.count(graph_key)) { | ||
graphs_[graph_key] = std::move(graph); | ||
} else { | ||
LOG(WARNING) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 会出现子图已经被注册的情况吗 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这种情况不确定,如果只是运用 |
||
<< "The graph being added is already in CinnCompiler. Its key is:\n" | ||
<< ReadableProtoStr(graph_key); | ||
} | ||
} | ||
return graph_key; | ||
} | ||
|
||
// Returns the graph stored in `graphs_` under `graph_key`.
//
// The lookup is performed while holding an AutoRDLock on `rwlock_`. If no
// graph is stored under the key, PADDLE_ENFORCE_NE raises an InvalidArgument
// error whose message contains the key rendered by ReadableProtoStr.
const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
  AutoRDLock guard{&rwlock_};
  // One find() serves both the existence check and the access.
  const auto iter = graphs_.find(graph_key);
  PADDLE_ENFORCE_NE(
      iter, graphs_.end(),
      platform::errors::InvalidArgument("Can not find the target graph:\n%s",
                                        ReadableProtoStr(graph_key).c_str()));
  return *iter->second;
}
|
||
// Returns the compiled object for (`graph`, `input_tensors`, target arch),
// compiling it first if `cache_` has no entry for that key.
//
// Locking: the existence check runs under a read lock, the compilation
// itself runs with no lock held, and the result is stored under a write
// lock. Because the check and the store are separate critical sections, two
// callers may compile the same key concurrently; the second result is then
// discarded and the entry stored first is returned to both.
const CinnCompiledObject& CinnCompiler::Compile(
    const Graph& graph,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const Target& target) {
  CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
  bool exist = false;
  {
    AutoRDLock r_guard{&rwlock_};
    exist = cache_.count(cur_key) != 0;
  }
  if (!exist) {
    // NOTE(review): this increment happens outside any lock; confirm that
    // real_compiled_num_ is an atomic (or otherwise synchronized) counter.
    real_compiled_num_++;
    auto compiled_res = CompileGraph(graph, input_tensors, target);
    AutoWRLock w_guard{&rwlock_};
    // Re-check under the write lock: another caller may have stored the same
    // key while this one was compiling.
    if (!cache_.count(cur_key)) {
      cache_[cur_key] = std::move(compiled_res);
    }
  }
  AutoRDLock guard{&rwlock_};
  // Use at() rather than operator[]: operator[] is a non-const access that
  // can insert, which must not happen while only the read lock is held.
  const auto& cached_obj = *cache_.at(cur_key);
  return cached_obj;
}
|
||
// Key-based overload: logs the graph identified by `compilation_key`,
// resolves the key to its stored graph through FindGraph, and forwards to
// the graph-based Compile with the same tensors and target.
const CinnCompiledObject& CinnCompiler::Compile(
    const std::string& compilation_key,
    const std::map<std::string, const LoDTensor*>& input_tensors,
    const Target& target) {
  VLOG(4) << "The graph to be compiled is:\n"
          << ReadableProtoStr(compilation_key);
  return Compile(FindGraph(compilation_key), input_tensors, target);
}
|
@@ -125,6 +147,12 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( | |
return compiled_obj; | ||
} | ||
|
||
// Renders `bytes`, expected to be a serialized proto::ProgramDesc, as
// human-readable text via DebugString() for logs and error messages.
//
// If `bytes` cannot be parsed as a complete ProgramDesc, the result is
// prefixed with a marker line so the reader is not misled into treating a
// partial (possibly empty) dump as the real program; whatever was parsed is
// still included after the marker.
std::string ReadableProtoStr(const std::string& bytes) {
  proto::ProgramDesc program_desc;
  const bool parsed = program_desc.ParseFromString(bytes);
  if (!parsed) {
    return "[malformed or incomplete ProgramDesc]\n" +
           program_desc.DebugString();
  }
  return program_desc.DebugString();
}
|
||
} // namespace paddle2cinn | ||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以放在外面吧?就不用每次都计算了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.