Merge pull request BVLC#740 from qipeng/lrelu
Leaky ReLU
jeffdonahue committed Jul 22, 2014
2 parents 49a8ea3 + dde9790 commit dc8bcf3
Showing 4 changed files with 54 additions and 9 deletions.
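The change makes the slope of the negative half of the ReLU configurable: the forward pass becomes max(x, 0) + negative_slope * min(x, 0), so the gradient is 1 for x >= 0 and negative_slope for x < 0. The default slope of 0 preserves the standard ReLU.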
8 changes: 6 additions & 2 deletions src/caffe/layers/relu_layer.cpp
@@ -14,8 +14,10 @@ Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   for (int i = 0; i < count; ++i) {
-    top_data[i] = std::max(bottom_data[i], Dtype(0));
+    top_data[i] = std::max(bottom_data[i], Dtype(0))
+        + negative_slope * std::min(bottom_data[i], Dtype(0));
   }
   return Dtype(0);
 }

@@ -29,8 +31,10 @@ void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] > 0);
+      bottom_diff[i] = top_diff[i] * ((bottom_data[i] >= 0)
+          + negative_slope * (bottom_data[i] < 0));
     }
   }
 }
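For readers outside the Caffe codebase, here is a minimal standalone sketch of the same forward and backward arithmetic on plain float arrays; the function names and the main driver are illustrative, not part of this commit.

#include <algorithm>
#include <cstdio>

// Forward: f(x) = max(x, 0) + negative_slope * min(x, 0),
// matching the expression added to ReLULayer::Forward_cpu above.
void leaky_relu_forward(const float* in, float* out, int count,
                        float negative_slope) {
  for (int i = 0; i < count; ++i) {
    out[i] = std::max(in[i], 0.0f) + negative_slope * std::min(in[i], 0.0f);
  }
}

// Backward: f'(x) = 1 for x >= 0, negative_slope otherwise,
// matching ReLULayer::Backward_cpu above.
void leaky_relu_backward(const float* in, const float* top_diff,
                         float* bottom_diff, int count, float negative_slope) {
  for (int i = 0; i < count; ++i) {
    bottom_diff[i] =
        top_diff[i] * ((in[i] >= 0) + negative_slope * (in[i] < 0));
  }
}

int main() {
  const float x[4] = {-2.0f, -0.5f, 0.0f, 3.0f};
  float y[4];
  leaky_relu_forward(x, y, 4, 0.01f);
  for (int i = 0; i < 4; ++i) {
    std::printf("f(%g) = %g\n", x[i], y[i]);  // -0.02, -0.005, 0, 3
  }
  return 0;
}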
16 changes: 10 additions & 6 deletions src/caffe/layers/relu_layer.cu
@@ -9,9 +9,10 @@
 namespace caffe {

 template <typename Dtype>
-__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
+    Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out[index] = in[index] > 0 ? in[index] : 0;
+    out[index] = in[index] > 0 ? in[index] : in[index] * negative_slope;
   }
 }

@@ -21,9 +22,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   // NOLINT_NEXT_LINE(whitespace/operators)
   ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-      count, bottom_data, top_data);
+      count, bottom_data, top_data, negative_slope);
   CUDA_POST_KERNEL_CHECK;
   // << " count: " << count << " bottom_data: "
   //     << (unsigned long)bottom_data

@@ -35,9 +37,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,

 template <typename Dtype>
 __global__ void ReLUBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const Dtype* in_data, Dtype* out_diff, Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out_diff[index] = in_diff[index] * (in_data[index] > 0);
+    out_diff[index] = in_diff[index] * ((in_data[index] >= 0)
+        + (in_data[index] < 0) * negative_slope);
   }
 }

@@ -50,9 +53,10 @@ void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     // NOLINT_NEXT_LINE(whitespace/operators)
     ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_data, bottom_diff);
+        count, top_diff, bottom_data, bottom_diff, negative_slope);
     CUDA_POST_KERNEL_CHECK;
   }
 }
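One detail worth noting in the backward code: the comparisons (in_data[index] >= 0) and (in_data[index] < 0) promote to 1 or 0, so the multiplier evaluates to 1 on the non-negative side and negative_slope on the negative side without a divergent branch. A tiny check of the idiom in plain C++ (the function name is illustrative):

#include <cassert>

// Multiplier used by ReLUBackward: bools promote to 1/0, so this is
// 1 when x >= 0 and slope when x < 0, with no branch.
float relu_grad_multiplier(float x, float slope) {
  return (x >= 0) + slope * (x < 0);
}

int main() {
  assert(relu_grad_multiplier(2.0f, 0.01f) == 1.0f);
  assert(relu_grad_multiplier(-2.0f, 0.01f) == 0.01f);
  assert(relu_grad_multiplier(0.0f, 0.01f) == 1.0f);
  return 0;
}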
13 changes: 12 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -117,7 +117,7 @@ message SolverState {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available ID: 28 (last added: accuracy_param)
+// LayerParameter next available ID: 31 (last added: relu_param)
 message LayerParameter {
   repeated string bottom = 2; // the name of the bottom blobs
   repeated string top = 3; // the name of the top blobs

@@ -207,6 +207,7 @@ message LayerParameter {
   optional MemoryDataParameter memory_data_param = 22;
   optional PoolingParameter pooling_param = 19;
   optional PowerParameter power_param = 21;
+  optional ReLUParameter relu_param = 30;
   optional WindowDataParameter window_data_param = 20;
   optional ThresholdParameter threshold_param = 25;
   optional HingeLossParameter hinge_loss_param = 29;

@@ -429,6 +430,16 @@ message PowerParameter {
   optional float shift = 3 [default = 0.0];
 }

+// Message that stores parameters used by ReLULayer
+message ReLUParameter {
+  // Allow non-zero slope for negative inputs to speed up optimization
+  // Described in:
+  // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
+  // improve neural network acoustic models. In ICML Workshop on Deep Learning
+  // for Audio, Speech, and Language Processing.
+  optional float negative_slope = 1 [default = 0];
+}
+
 // Message that stores parameters used by WindowDataLayer
 message WindowDataParameter {
   // Specify the data source.
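With relu_param wired into LayerParameter, a network definition can opt into the leaky behavior from a prototxt. A hypothetical snippet, assuming the layers/type-enum syntax Caffe used at the time (the layer and blob names are illustrative):

layers {
  name: "relu1"
  type: RELU
  bottom: "conv1"
  top: "conv1"
  relu_param {
    negative_slope: 0.01
  }
}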
26 changes: 26 additions & 0 deletions src/caffe/test/test_neuron_layer.cpp
@@ -62,6 +62,32 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradient) {
       &(this->blob_top_vec_));
 }

+TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.ParseFromString("relu_param{negative_slope:0.01}");
+  ReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.ParseFromString("relu_param{negative_slope:0.01}");
+  ReLULayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
 TYPED_TEST(NeuronLayerTest, TestSigmoid) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
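Two caveats for anyone adapting TestReLUWithNegativeSlope: Message::ParseFromString expects protobuf's binary wire format, so the human-readable string above leaves negative_slope at its default of 0, and the value checks only hold for that default (a leaky unit maps a negative input to a negative output, which EXPECT_GE rules out). A sketch of a test body consistent with the leaky definition, assuming the same fixture and the google::protobuf::TextFormat API; this is illustrative, not part of this commit:

  typedef typename TypeParam::Dtype Dtype;
  LayerParameter layer_param;
  // TextFormat parses the human-readable form; ParseFromString alone would not.
  CHECK(google::protobuf::TextFormat::ParseFromString(
      "relu_param { negative_slope: 0.01 }", &layer_param));
  ReLULayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
  const Dtype* top_data = this->blob_top_->cpu_data();
  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
    if (bottom_data[i] >= 0) {
      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]);
    } else {
      // Negative inputs are scaled by the slope, not clamped to zero.
      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] * Dtype(0.01));
    }
  }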
