RNN backward create #3490
Conversation
… rnn-backward-create
paddle/operators/recurrent_op.h
Outdated
/*
 * Some special preprocessing after a gradient op is created.
 */
static void Init(const RecurrentOp& op, RecurrentGradientOp* grad_op,
Init is a bad pattern; move the work of Init() into the constructor.
paddle/operators/recurrent_op.cc
Outdated
void RecurrentGradientOp::Init(
    const RecurrentOp& op, RecurrentGradientOp* grad_op,
    const std::unordered_set<std::string>& no_grad_vars) {
  auto gradop = Backward(op.stepnet(), no_grad_vars);
If this is meant to be a recursive call, it should call the BackwardRecursive function rather than Backward. There are also some parameters, such as the uid, that need to be passed along through the recursion.
@@ -178,11 +179,24 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     return false;
   });

   // process recurrent gradient op as a special operator.
   if (forwardOp.Type() == "recurrent_op") {
     // NOTE clean up cycle call somewhere (RNN's stepnet contains itself), or
Just writing a create method here should be the simplest approach.
resolve #3472