-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support testing when training and handle dropout and batch_norm operator in testing mode #5734
Conversation
paddle/framework/prune.cc
Outdated
op_desc.type() == kBatchNormOpType) { | ||
for (auto& attr : *op_desc.mutable_attrs()) { | ||
if (attr.name() == "is_test") { | ||
attr.set_b(true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
break;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/framework/prune.cc
Outdated
void Prune(const ProgramDesc& input, ProgramDesc* output) { | ||
prune_impl(input, output, 0); | ||
void Prune(const ProgramDesc& input, ProgramDesc* output, bool is_test) { | ||
prune_impl(input, output, 0, is_test); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think setting attribute is_test
true has no relationship with prune
. We shall add an independent function or module to do this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now have prune
method to transform ProgramDesc and generate another ProgramDesc.
Exactly, we will do some optimization in inference stage and handle some special operators. Then, we will generate an inference ProgramDesc.
You are right. Maybe we need another method like prune
, the interface could be:
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
@@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) { | |||
prune_impl(input, output, 0); | |||
} | |||
|
|||
void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic of inference_optimize_impl
is quite simple. Maybe we can implement it in Python.
Fix #5733 and Fix #5814