ELU layer with basic tests
mohomran committed Dec 4, 2015
1 parent 7e40583 commit a668194
Showing 5 changed files with 264 additions and 1 deletion.
86 changes: 86 additions & 0 deletions include/caffe/layers/elu_layer.hpp
@@ -0,0 +1,86 @@
#ifndef CAFFE_ELU_LAYER_HPP_
#define CAFFE_ELU_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/neuron_layer.hpp"

namespace caffe {

/**
* @brief Exponential Linear Unit non-linearity @f$
* y = \left\{
* \begin{array}{lr}
* x & \mathrm{if} \; x > 0 \\
* \alpha (\exp(x)-1) & \mathrm{if} \; x \le 0
* \end{array} \right.
* @f$.
*/
template <typename Dtype>
class ELULayer : public NeuronLayer<Dtype> {
public:
/**
* @param param provides ELUParameter elu_param,
* with ELULayer options:
* - alpha (\b optional, default 1).
* The value @f$ \alpha @f$ that controls saturation for negative inputs.
*/
explicit ELULayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {}

virtual inline const char* type() const { return "ELU"; }

protected:
/**
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$
* @param top output Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the computed outputs @f$
* y = \left\{
* \begin{array}{lr}
* x & \mathrm{if} \; x > 0 \\
* \alpha (\exp(x)-1) & \mathrm{if} \; x \le 0
* \end{array} \right.
* @f$.
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

/**
* @brief Computes the error gradient w.r.t. the ELU inputs.
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (N \times C \times H \times W) @f$
* containing error gradients @f$ \frac{\partial E}{\partial y} @f$
* with respect to computed outputs @f$ y @f$
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$; Backward fills their diff with
* gradients @f$
* \frac{\partial E}{\partial x} = \left\{
* \begin{array}{lr}
* \frac{\partial E}{\partial y} & \mathrm{if} \; x > 0 \\
* \frac{\partial E}{\partial y} \, (y + \alpha) & \mathrm{if} \; x \le 0
* \end{array} \right.
* @f$ if propagate_down[0]. For @f$ x \le 0 @f$, @f$ y + \alpha = \alpha \exp(x) @f$
* is the derivative of the forward mapping, so the saved outputs can be
* reused instead of recomputing the exponential.
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
};


} // namespace caffe

#endif // CAFFE_ELU_LAYER_HPP_
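
As a side note (not part of this commit), the piecewise definition documented in the header can be sanity-checked with a small standalone C++ snippet; the helper names and sample values below are illustrative only:

#include <cmath>
#include <cstdio>

// Scalar ELU and its derivative, following the formulas in elu_layer.hpp.
double elu(double x, double alpha) {
  return x > 0 ? x : alpha * (std::exp(x) - 1.0);
}

double elu_grad(double x, double alpha) {
  // For x <= 0, dy/dx = alpha * exp(x), which equals y + alpha.
  return x > 0 ? 1.0 : alpha * std::exp(x);
}

int main() {
  const double alpha = 1.0;
  std::printf("elu(1.5)   = %f\n", elu(1.5, alpha));        // 1.5: identity for x > 0
  std::printf("elu(-1.0)  = %f\n", elu(-1.0, alpha));       // exp(-1) - 1 ~= -0.632
  std::printf("grad(-1.0) = %f\n", elu_grad(-1.0, alpha));  // exp(-1) ~= 0.368
  return 0;
}

With alpha = 1 the negative branch saturates at -1 as x tends to minus infinity, giving a smooth alternative to ReLU's hard cutoff at zero.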
47 changes: 47 additions & 0 deletions src/caffe/layers/elu_layer.cpp
@@ -0,0 +1,47 @@
#include <algorithm>
#include <vector>

#include "caffe/layers/elu_layer.hpp"

namespace caffe {

template <typename Dtype>
void ELULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const int count = bottom[0]->count();
Dtype alpha = this->layer_param_.elu_param().alpha();
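// Branch-free form: std::max(x, 0) handles the positive case, and for x > 0
// the exponential term vanishes since exp(std::min(x, 0)) - 1 = 0, so the two
// terms together implement the piecewise definition without a branch.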
for (int i = 0; i < count; ++i) {
top_data[i] = std::max(bottom_data[i], Dtype(0))
+ alpha * (exp(std::min(bottom_data[i], Dtype(0))) - Dtype(1));
}
}

template <typename Dtype>
void ELULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* top_data = top[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
const int count = bottom[0]->count();
Dtype alpha = this->layer_param_.elu_param().alpha();
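// For x <= 0 the local derivative is alpha * exp(x) = y + alpha, so the
// saved top data is reused instead of recomputing the exponential.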
for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * ((bottom_data[i] > 0)
+ (alpha + top_data[i]) * (bottom_data[i] <= 0));
}
}
}


#ifdef CPU_ONLY
STUB_GPU(ELULayer);
#endif

INSTANTIATE_CLASS(ELULayer);
REGISTER_LAYER_CLASS(ELU);

} // namespace caffe
62 changes: 62 additions & 0 deletions src/caffe/layers/elu_layer.cu
@@ -0,0 +1,62 @@
#include <algorithm>
#include <vector>

#include "caffe/layers/elu_layer.hpp"

namespace caffe {

template <typename Dtype>
__global__ void ELUForward(const int n, const Dtype* in, Dtype* out,
Dtype alpha) {
CUDA_KERNEL_LOOP(index, n) {
out[index] = in[index] > 0 ? in[index] :
alpha * (exp(in[index]) - 1);
}
}

template <typename Dtype>
void ELULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int count = bottom[0]->count();
Dtype alpha = this->layer_param_.elu_param().alpha();
// NOLINT_NEXT_LINE(whitespace/operators)
ELUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data, alpha);
CUDA_POST_KERNEL_CHECK;
}

template <typename Dtype>
__global__ void ELUBackward(const int n, const Dtype* in_diff,
const Dtype* out_data, const Dtype* in_data,
Dtype* out_diff, Dtype alpha) {
CUDA_KERNEL_LOOP(index, n) {
out_diff[index] = in_data[index] > 0 ? in_diff[index] :
in_diff[index] * (out_data[index] + alpha);
}
}

template <typename Dtype>
void ELULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* top_data = top[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int count = bottom[0]->count();
Dtype alpha = this->layer_param_.elu_param().alpha();
// NOLINT_NEXT_LINE(whitespace/operators)
ELUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top_data, bottom_data, bottom_diff, alpha);
CUDA_POST_KERNEL_CHECK;
}
}


INSTANTIATE_LAYER_GPU_FUNCS(ELULayer);


} // namespace caffe
11 changes: 10 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -306,7 +306,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 140 (last added: batch_norm_param)
// LayerParameter next available layer-specific ID: 141 (last added: elu_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -363,6 +363,7 @@ message LayerParameter {
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
optional ELUParameter elu_param = 140;
optional EmbedParameter embed_param = 137;
optional ExpParameter exp_param = 111;
optional FlattenParameter flatten_param = 135;
@@ -629,6 +630,14 @@ message EltwiseParameter {
optional bool stable_prod_grad = 3 [default = true];
}

// Message that stores parameters used by ELULayer
message ELUParameter {
// Described in:
// Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
// Deep Network Learning by Exponential Linear Units (ELUs). arXiv:1511.07289
optional float alpha = 1 [default = 1];
}

// Message that stores parameters used by EmbedLayer
message EmbedParameter {
optional uint32 num_output = 1; // The number of outputs for the layer
59 changes: 59 additions & 0 deletions src/caffe/test/test_neuron_layer.cpp
@@ -11,6 +11,7 @@
#include "caffe/layers/absval_layer.hpp"
#include "caffe/layers/bnll_layer.hpp"
#include "caffe/layers/dropout_layer.hpp"
#include "caffe/layers/elu_layer.hpp"
#include "caffe/layers/exp_layer.hpp"
#include "caffe/layers/inner_product_layer.hpp"
#include "caffe/layers/log_layer.hpp"
@@ -259,6 +260,64 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
this->blob_top_vec_);
}

TYPED_TEST(NeuronLayerTest, TestELU) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
CHECK(google::protobuf::TextFormat::ParseFromString(
"elu_param { alpha: 0.5 }", &layer_param));
ELULayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
const Dtype kDelta = 2e-4;
// Now, check values
const Dtype* bottom_data = this->blob_bottom_->cpu_data();
const Dtype* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
if (bottom_data[i] > 0) {
EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]);
} else {
EXPECT_NEAR(top_data[i], 0.5 * (exp(bottom_data[i]) - 1), kDelta);
}
}
}

TYPED_TEST(NeuronLayerTest, TestELUasReLU) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
CHECK(google::protobuf::TextFormat::ParseFromString(
"elu_param { alpha: 0 }", &layer_param));
ELULayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
const Dtype* bottom_data = this->blob_bottom_->cpu_data();
const Dtype* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_GE(top_data[i], 0.);
EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
}
}

TYPED_TEST(NeuronLayerTest, TestELUGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ELULayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
this->blob_top_vec_);
}

TYPED_TEST(NeuronLayerTest, TestELUasReLUGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
CHECK(google::protobuf::TextFormat::ParseFromString(
"elu_param { alpha: 0 }", &layer_param));
ELULayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
this->blob_top_vec_);
}

TYPED_TEST(NeuronLayerTest, TestSigmoid) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
