Commit c30bfc6

init
gongweibao committed Nov 22, 2017
1 parent 53bd51e commit c30bfc6
Showing 3 changed files with 131 additions and 30 deletions.
40 changes: 40 additions & 0 deletions paddle/operators/lrn_op.cc
@@ -19,6 +19,46 @@ namespace operators {

using framework::Tensor;

template <typename T>
struct LRNFunctor<platform::CPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* input, int N, int C, int H, int W,
int n, T alpha, T beta, T k, framework::Tensor* mid,
framework::Tensor* out) {
auto x_v = framework::EigenVector<T>::Flatten(*input);

const int start = -(n - 1) / 2;
const int end = start + n;

auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid.device(ctx.GetEigenDevice<platform::CPUPlace>()) = e_mid.constant(k);

auto e_x = framework::EigenTensor<T, 4>::From(*input);
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch >= 0 && ch < C) {
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));

auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));

s.device(ctx.GetEigenDevice<platform::CPUPlace>()) += alpha * r.square();
}
}
}
}

auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
}
};
template struct LRNFunctor<platform::CPUPlace, float>;
template struct LRNFunctor<platform::CPUPlace, double>;

class LRNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
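For reference (not part of the diff), the CPU functor added above computes the standard cross-channel LRN quantities; `mid` holds the squared-sum term and `out` the normalized output:

$$
\mathrm{mid}_{m,i,h,w} = k + \alpha \sum_{c} x_{m,i+c,h,w}^{2},
\qquad
\mathrm{out}_{m,i,h,w} = x_{m,i,h,w} \cdot \bigl(\mathrm{mid}_{m,i,h,w}\bigr)^{-\beta}
$$

where c ranges over the loop bounds in the code, start = -(n - 1) / 2 to end = start + n, with channel indices clamped to [0, C).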
82 changes: 81 additions & 1 deletion paddle/operators/lrn_op.cu
@@ -15,8 +15,88 @@
#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
int H, int W, int size, T k, T alpha) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < img_size) {
const int w = idx % W;
const int h = (idx / W) % H;
const int n = idx / W / H;
const int offset = (n * C * H + h) * W + w;

in += offset;
mid += offset;
const int step = H * W;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;

T accum = 0;
int index = 0;
while (index < C + post_pad) {
if (index < C) {
accum += in[index * step] * in[index * step];
}
if (index >= size) {
accum -= in[(index - size) * step] * in[(index - size) * step];
}
if (index >= post_pad) {
mid[(index - post_pad) * step] = k + accum * alpha;
}
++index;
}
}
}

template <typename T>
__global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
T negative_beta, T* out) {
const int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < input_size) {
out[index] = in[index] * pow(mid[index], negative_beta);
}
}

template <typename T>
void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
T alpha, T beta) {
int img_size = N * H * W;
int block_size = 1024;
int grid_size = (img_size + 1024 - 1) / 1024;

auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx.device_context())
.stream();
KeCMRNormFillScale<T><<<grid_size, block_size, 0, stream>>>(
img_size, inputs, mid, C, H, W, n, k, alpha);

int input_size = N * H * W * C;
block_size = 1024;
grid_size = (input_size + 1024 - 1) / 1024;
KeCMRNormOutput<T><<<grid_size, block_size, 0, stream>>>(
input_size, inputs, mid, -beta, outputs);
}

template <typename T>
struct LRNFunctor<platform::GPUPlace, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* input, int N, int C, int H, int W,
int n, T alpha, T beta, T k, framework::Tensor* mid,
framework::Tensor* out) {
CrossMapNormal<T>(ctx, input->data<T>(),
out->mutable_data<T>(platform::GPUPlace()),
mid->mutable_data<T>(platform::GPUPlace()), N, C, H, W, n,
k, alpha, beta);
}
};

template struct LRNFunctor<platform::GPUPlace, float>;
template struct LRNFunctor<platform::GPUPlace, double>;
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
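The fill-scale kernel above walks each (n, h, w) column once and keeps a running sum of squares over the channel window: each step adds the channel entering the window, subtracts the one leaving it, and writes a result once post_pad channels of look-ahead are available. A minimal host-side C++ sketch of the same scan, for illustration only (not part of the commit; names and layout are hypothetical):

#include <vector>

// Reference scan over one channel column x[0..C-1] (channel stride collapsed to 1).
// Mirrors KeCMRNormFillScale: mid[c] = k + alpha * sum of x[j]^2 for j in
// [c - pre_pad, c + post_pad], clamped to [0, C).
std::vector<float> FillScaleColumn(const std::vector<float>& x, int size,
                                   float k, float alpha) {
  const int C = static_cast<int>(x.size());
  const int pre_pad = (size - 1) / 2;
  const int post_pad = size - pre_pad - 1;
  std::vector<float> mid(C, 0.f);
  float accum = 0.f;
  for (int index = 0; index < C + post_pad; ++index) {
    if (index < C) accum += x[index] * x[index];                      // channel enters window
    if (index >= size) accum -= x[index - size] * x[index - size];    // channel leaves window
    if (index >= post_pad) mid[index - post_pad] = k + accum * alpha; // write window center
  }
  return mid;
}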
39 changes: 10 additions & 29 deletions paddle/operators/lrn_op.h
@@ -21,6 +21,14 @@
namespace paddle {
namespace operators {

template <typename Place, typename T>
struct LRNFunctor {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* input, int N, int C, int H, int W,
int n, T alpha, T beta, T k, framework::Tensor* mid,
framework::Tensor* out);
};

template <typename Place, typename T>
class LRNKernel : public framework::OpKernel<T> {
public:
@@ -57,35 +65,8 @@ class LRNKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");

auto x_v = framework::EigenVector<T>::Flatten(*x);

const int start = -(n - 1) / 2;
const int end = start + n;

auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid.device(ctx.GetEigenDevice<Place>()) = e_mid.constant(k);

auto e_x = framework::EigenTensor<T, 4>::From(*x);
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
for (int c = start; c <= end; c++) {
int ch = i + c;
if (ch >= 0 && ch < C) {
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));

auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));

s.device(ctx.GetEigenDevice<Place>()) += alpha * r.square();
}
}
}
}

auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(ctx.GetEigenDevice<Place>()) =
x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
LRNFunctor<Place, T> f;
f(ctx, x, N, C, H, W, n, alpha, beta, k, mid, out);
}
};


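The net effect of the diff is the usual place-specialized functor pattern: lrn_op.h declares a generic LRNFunctor<Place, T>, lrn_op.cc and lrn_op.cu provide CPU and GPU specializations with explicit instantiations, and the shared LRNKernel simply forwards to it. A stripped-down sketch of that pattern (illustrative names only, not Paddle code):

#include <iostream>

// Generic declaration in the shared header: no body, only specializations exist.
template <typename Place, typename T>
struct ScaleFunctor;

struct CPUPlace {};  // stand-ins for the real place tags
struct GPUPlace {};

// CPU specialization lives in the .cc file.
template <typename T>
struct ScaleFunctor<CPUPlace, T> {
  void operator()(const T* in, T* out, int n, T factor) const {
    for (int i = 0; i < n; ++i) out[i] = in[i] * factor;
  }
};
template struct ScaleFunctor<CPUPlace, float>;  // explicit instantiation

// A device-agnostic kernel forwards to whichever specialization matches Place.
template <typename Place, typename T>
void Compute(const T* in, T* out, int n, T factor) {
  ScaleFunctor<Place, T> f;
  f(in, out, n, factor);
}

int main() {
  float in[3] = {1.f, 2.f, 3.f}, out[3];
  Compute<CPUPlace, float>(in, out, 3, 2.f);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // prints: 2 4 6
}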