Add backward implementation for LSTM operator. #5115
Conversation
…he activation function pointer. It will be fixed later.
paddle/operators/lstm_op.cc
Outdated
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.") | ||
.AsDispensable(); | ||
AddOutput("Hidden", | ||
"(LoDTensor) the hidden state lod tensor of LSTM operator. " |
For "the hidden state of LSTM operator", is the "lod tensor" in the middle redundant?
Done. Removed "lod tensor" and fixed the shape info.
paddle/operators/lstm_op.cc
Outdated
"(LoDTensor) the hidden state lod tensor of LSTM operator. " | ||
"The shape and lod is the same with the `Input`."); | ||
AddOutput("Cell", | ||
"(LoDTensor) the cell state lod tensor of LSTM operator. " |
Same as above.
Done.
paddle/operators/lstm_op.cc
Outdated
"The shape and lod is the same with the `Input`."); | ||
AddOutput("BatchCellPreAct", | ||
"(LoDTensor) This LoDTensor is get in the forward and used " | ||
"in the backward.") |
get -> got
Done.
paddle/operators/lstm_op.h
Outdated
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
// auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
Can line 158 be deleted?
Done.
paddle/operators/lstm_op.h
Outdated
// auto bias_g_e = EigenVector<T>::Flatten(bias_mat);
// auto gate_g_e = EigenMatrix<T>::From(batch_gate_g);
// Eigen::array<int, 1> dims{{0}};
// bias_g_e.device(ctx.GetEigenDevice<Place>()) = gate_g_e.sum(dims);
Are lines 295-304 a TODO?
These lines are the equivalent Eigen code, but Eigen does not support the double type on GPU devices, so GEMV is used instead. And I have removed these lines.
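For reference, the GEMV trick replaces a column-wise Eigen `sum` over the gate gradients with a matrix-vector product against a ones vector. A minimal standalone sketch using plain CBLAS rather than Paddle's internal wrappers (`bias_grad_via_gemv`, `gate_g`, `batch`, and `width` are illustrative names, not the PR's actual identifiers):

```cpp
// Sketch only: computes bias_g[j] = sum_i gate_g[i][j], i.e. the column-wise
// sum of the (batch x width) gate-gradient matrix, expressed as a GEMV with
// a ones vector: bias_g = gate_g^T * ones.
#include <cblas.h>
#include <vector>

void bias_grad_via_gemv(const double* gate_g, int batch, int width,
                        double* bias_g) {
  std::vector<double> ones(batch, 1.0);  // reduction vector over the batch dim
  // y = 1.0 * A^T * x + 0.0 * y, with A = gate_g stored row-major.
  cblas_dgemv(CblasRowMajor, CblasTrans, batch, width,
              /*alpha=*/1.0, gate_g, /*lda=*/width, ones.data(), /*incx=*/1,
              /*beta=*/0.0, bias_g, /*incy=*/1);
}
```

Expressing the reduction as `gate_g^T * ones` keeps it inside a single BLAS call, which is available for both float and double on CPU and GPU backends, unlike the Eigen `sum` path mentioned above.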
paddle/operators/lstm_op.h
Outdated
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();

if (n != 0) {
`if (n != 0)` -> `if (n)`?
Done.
      ASSERT_FLOAT_EQ(data_c[i], sum);
    }
  }
}
A lot of this code looks the same as the .cc unit test. Will you consider sharing it in a follow-up?
It can be improved in a follow-up.
You could create an issue and paste its URL here, so it won't be forgotten later.
Created an issue: #5234
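One possible shape for that issue: move the duplicated host-side checks into a header templated on the element type, keeping only the device-specific setup (allocation, kernel launch, copy back to host) in the separate .cc and .cu test files. A hypothetical sketch, not PaddlePaddle's actual test layout (`CheckAllClose` is an invented name):

```cpp
#include <gtest/gtest.h>

// Shared verification: both the CPU and GPU tests copy their results into a
// host buffer and run the same assertion loop.
template <typename T>
void CheckAllClose(const T* actual, const T* expected, int size) {
  for (int i = 0; i < size; ++i) {
    ASSERT_FLOAT_EQ(actual[i], expected[i]);
  }
}

// Usage example with host data; a GPU test would copy results back first.
TEST(SharedCheckExample, Matches) {
  float got[3] = {1.f, 2.f, 3.f};
  float want[3] = {1.f, 2.f, 3.f};
  CheckAllClose(got, want, 3);
}
```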
Which ones, specifically?
@luotao1 Thanks for your review.
The enhancements that still need to be done are listed in this comment: #5115 (comment)
There are also TODO comments in the code.
I am far from an expert on this change, but it seems to have stabilized for a while, so I approved it.
Fix #5114
Some enhancements will be done in the next PR. The activations are currently fixed (`Sigmoid` and `Tanh`) since there is a bug with the activation function pointer. Support for activations specified by users will be added later.