Add efficient GPU implementation for LRN. #5894
Conversation
paddle/operators/lrn_op.cc
Outdated
template <typename T>
struct LRNFunctor<platform::CPUPlace, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* input, framework::Tensor* out,
For input arguments: const framework::Tensor&
https://google.github.io/styleguide/cppguide.html#Reference_Arguments
Done!
Thanks!
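The style point above can be illustrated with a minimal sketch (using a hypothetical `SimpleTensor` type, not the actual Paddle API): read-only inputs are taken by const reference, while the output is passed as a pointer, per the Google C++ style guide the reviewer links.

```cpp
#include <vector>

// Hypothetical stand-in for framework::Tensor, for illustration only.
struct SimpleTensor {
  std::vector<float> data;
};

// Inputs by const reference (read-only), outputs by pointer (mutated):
// mirrors the reviewer's suggestion of `const framework::Tensor&` for inputs.
void Scale(const SimpleTensor& input, float factor, SimpleTensor* out) {
  out->data.clear();
  for (float v : input.data) out->data.push_back(v * factor);
}
```

The pointer for the output parameter makes mutation visible at the call site (`Scale(in, 2.0f, &out)`), which is the rationale behind the style rule.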
paddle/operators/lrn_op.cc
Outdated
const int end = start + n;

auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid.device(ctx.GetEigenDevice<platform::CPUPlace>()) = e_mid.constant(k);
For the CPU implementation with Eigen, there is no need to use .device():
e_mid.setConstant(k);
Done!
Thanks!
paddle/operators/lrn_op.cc
Outdated
Eigen::array<int, 4>({{1, 1, H, W}}));

s.device(ctx.GetEigenDevice<platform::CPUPlace>()) +=
    alpha * r.square();
The same as above:
s += alpha * r.square();
Done!
Thanks!
paddle/operators/lrn_op.cc
Outdated
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
    x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
The same as above.
Done!
Thanks!
paddle/operators/lrn_op.cc
Outdated
void operator()(const framework::ExecutionContext& ctx,
                const framework::Tensor* x, const framework::Tensor* out,
                const framework::Tensor* mid, framework::Tensor* x_g,
                const framework::Tensor* out_g, int N, int C, int H, int W,
For the input arguments, the same as the comments above.
Done!
Thanks!
paddle/operators/lrn_op.cu
Outdated
T alpha, T beta) {
  int img_size = N * H * W;
  int block_size = 1024;
  int grid_size = (img_size + 1024 - 1) / 1024;
Use block_size in place of the 1024 on line 69:
int grid_size = (img_size + block_size - 1) / block_size;
Done!
Thanks!
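The suggested expression is the standard ceiling-division pattern for choosing a CUDA launch configuration: the smallest grid such that grid_size * block_size covers every element. A host-side sketch (plain C++, no CUDA required; the function name is illustrative) makes the property easy to check:

```cpp
// Ceiling division: returns the smallest grid_size such that
// grid_size * block_size >= img_size, so every element gets a thread.
int GridSize(int img_size, int block_size) {
  return (img_size + block_size - 1) / block_size;
}
```

Writing it in terms of block_size (rather than a hard-coded 1024) keeps the launch configuration correct if the block size is ever tuned.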
paddle/operators/lrn_op.cu
Outdated
int input_size = N * H * W * C;
block_size = 1024;
grid_size = (input_size + 1024 - 1) / 1024;
Same as above: use block_size in place of the 1024 on line 79.
Done!
Thanks!
paddle/operators/lrn_op.cu
Outdated
}
if (index >= size) {
  accum -= in[(index - size) * step] * in[(index - size) * step];
}
On lines 41 and 44, you can use a register to hold the value loaded from global memory first, which avoids reading global memory multiple times:
if (index < C) {
  T val = in[index * step];
  accum += val * val;
}
if (index >= size) {
  T val = in[(index - size) * step];
  accum -= val * val;
}
paddle/operators/lrn_op.cu
Outdated
const auto& stream =
    reinterpret_cast<const platform::CUDADeviceContext&>(ctx.device_context())
        .stream();
Done!
Thanks!
paddle/operators/lrn_op.cu
Outdated
int img_size = N * H * W;

int block_size = 1024;
int grid_size = (img_size + 1024 - 1) / 1024;
Same as above.
Done!
Thanks!
LGTM.
Fix #5066