-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. #36842
Conversation
Thanks for your contribution! |
add+relu
loss values (CINN).add+relu
loss values (CINN Test).
cinn_losses = train(self.tmpdir, "paddle") | ||
set_cinn_flag(False) | ||
pd_losses = train(self.tmpdir, "cinn") | ||
np.allclose(cinn_losses, pd_losses) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.allclose(cinn_losses, pd_losses) | |
self.assertTrue(np.allclose(cinn_losses, pd_losses, atol=1e-5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
… update_cinn_id
… update_cinn_id
… update_cinn_id
add+relu
loss values (CINN Test).if (!graphs_.count(graph_key)) { | ||
graphs_[graph_key] = std::move(graph); | ||
} else { | ||
LOG(WARNING) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
会出现子图已经被注册的情况吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种情况不确定,如果只是运用build_cinn_pass
一次,应该不会出现,这里已改为使用PADDLE_ENFORCE
。
@@ -59,40 +61,60 @@ std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { | |||
ProgramDesc program; | |||
GraphToProgram(*graph, &program); | |||
program.Proto()->SerializeToString(&graph_key); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
序列化的字符串直接作为key比较冗余,查找效率低、还占空间。是否以其hash值作为key,CinnCompiler额外提供接口可以由key获取其子图的序列化字符串?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是,查找时是string->hash吧,我意思是key存hash code,CinnCompiler提供接口由key获取graph->debug string替代ReadableProtoStr,不过这不重要。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for the use of const_cast and op benchmark ci
auto* cinn_compiler = CinnCompiler::GetInstance(); | ||
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key); | ||
// viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph)); | ||
viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个测试生成的文件是不是应该注释掉?或者在函数最后删掉?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
方便后续调试,先不删除,ci上build目录生成的文件会被自动清理的。
} | ||
// if the op type is registered in CINN and deny_ops is not empty, return | ||
// true only when it is not in deny_ops | ||
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
node->Name()) != nullptr; | ||
// if the op type is registered in CINN and allow_ops is not empty, return | ||
// true only when it is in allow_ops | ||
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以放在外面吧?就不用每次都计算了
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); | |
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); | |
auto teller = [&allow_ops](const Node* node) { ... }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
… update_cinn_id
0f86f5f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
FLAGS_allow_cinn_ops
和FLAGS_deny_cinn_ops
两个flag,用于控制Paddle训练中使用CINN算子代替原生算子的种类。CinnCompiler
类在多线程运行环境下的正确性(主要关于对cache_的操作)。CinnCompiler
类增加VizGraph
和ReadableKey
方法,用于打印子图的dot信息以及获取具备可读性的CompilationKey。test_parallel_executor_run_cinn.py
单测中增加对以下两种情况中add+relu
模型loss值的比较:add+relu
模型add+relu
模型