
replace cuDNN alphas and betas with coefficient values
Give cuDNN {0, 1} constants for controlling accumulation through the
alpha and beta coefficients.
shelhamer committed Mar 5, 2015
1 parent 91a6597 commit 2ddbb04
Showing 8 changed files with 65 additions and 85 deletions.
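A note on the semantics, since every hunk below follows the same pattern: cuDNN routines compute y = alpha * op(x) + beta * y, so passing alpha = 1 with beta = 0 overwrites the output, while beta = 1 accumulates into it. The sketch below is a standalone illustration, not the Caffe source — fake_cudnn_op is a hypothetical stand-in for a real cuDNN call — showing the pattern this commit introduces: per-type {0, 1} constants exposed as const void* so call sites can hand them straight to cuDNN as the alpha/beta arguments.

#include <iostream>

// Per-type {0, 1} constants, mirroring cudnn::dataType<Dtype> in this commit.
template <typename Dtype> class dataType;
template <> class dataType<float> {
 public:
  static float oneval, zeroval;
  static const void *one, *zero;
};
float dataType<float>::oneval = 1.0f;
float dataType<float>::zeroval = 0.0f;
const void* dataType<float>::one = &dataType<float>::oneval;
const void* dataType<float>::zero = &dataType<float>::zeroval;

// Hypothetical stand-in for a cuDNN routine: it reads alpha and beta
// through void pointers, as the real API does.
void fake_cudnn_op(const void* alpha, const void* beta,
                   const float* x, float* y, int n) {
  const float a = *static_cast<const float*>(alpha);
  const float b = *static_cast<const float*>(beta);
  for (int i = 0; i < n; ++i) y[i] = a * x[i] + b * y[i];
}

int main() {
  float x[3] = {1, 2, 3}, y[3] = {10, 10, 10};
  // beta = zero overwrites y; passing dataType<float>::one instead
  // would accumulate the result into the existing contents of y.
  fake_cudnn_op(dataType<float>::one, dataType<float>::zero, x, y, 3);
  std::cout << y[0] << " " << y[1] << " " << y[2] << "\n";  // prints: 1 2 3
}

Hoisting the constants also removes the per-call Dtype locals and the reinterpret_cast<void *>(&alpha) casts that previously appeared at every cuDNN call site in the diffs below.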
14 changes: 9 additions & 5 deletions include/caffe/util/cudnn.hpp
@@ -50,10 +50,14 @@ template <typename Dtype> class dataType;
 template<> class dataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
+  static float oneval, zeroval;
+  static const void *one, *zero;
 };
 template<> class dataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
+  static double oneval, zeroval;
+  static const void *one, *zero;
 };
 
 template <typename Dtype>
@@ -102,9 +106,9 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
 }
 
 template <typename Dtype>
-inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
+inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
     PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
-    int h, int w, int stride_h, int stride_w) {
+    int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
   switch (poolmethod) {
   case PoolingParameter_PoolMethod_MAX:
     *mode = CUDNN_POOLING_MAX;
@@ -115,9 +119,9 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
   default:
     LOG(FATAL) << "Unknown pooling method.";
   }
-  CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv));
-  CUDNN_CHECK(cudnnSetPooling2dDescriptor(*conv, *mode, h, w,
-      0, 0, stride_h, stride_w));
+  CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
+  CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
+      pad_h, pad_w, stride_h, stride_w));
 }
 
 } // namespace cudnn
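The one/zero members are const void* rather than Dtype because cuDNN takes alpha and beta by pointer, and the pointee must match the computation type (float for CUDNN_DATA_FLOAT, double for CUDNN_DATA_DOUBLE); keying the constants off dataType<Dtype> selects the right width at compile time, where the old locals had to be passed through reinterpret_cast<void *>.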
35 changes: 13 additions & 22 deletions src/caffe/layers/cudnn_conv_layer.cu
@@ -21,9 +21,6 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
 
     // Forward through cuDNN in parallel over groups.
     for (int g = 0; g < this->group_; g++) {
-      Dtype alpha = 1.0;
-      Dtype beta = 0.0;
-
       cudnnConvolutionFwdAlgo_t algo;
 
       // get the desired convolution algorithm
@@ -59,23 +56,21 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
 
       // Filters.
       CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
-          reinterpret_cast<void *>(&alpha),
+          cudnn::dataType<Dtype>::one,
           bottom_descs_[i], bottom_data + bottom_offset_ * g,
           filter_desc_, weight + weight_offset_ * g,
           conv_descs_[i],
           algo, workspace, workspaceSizeInBytes,
-          reinterpret_cast<void *>(&beta),
+          cudnn::dataType<Dtype>::zero,
           top_descs_[i], top_data + top_offset_ * g));
 
       // Bias.
       if (this->bias_term_) {
         const Dtype* bias_data = this->blobs_[1]->gpu_data();
-        Dtype alpha = 1.0;
-        Dtype beta = 1.0;
         CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
-            reinterpret_cast<void *>(&alpha),
+            cudnn::dataType<Dtype>::one,
             bias_desc_, bias_data + bias_offset_ * g,
-            reinterpret_cast<void *>(&beta),
+            cudnn::dataType<Dtype>::one,
             top_descs_[i], top_data + top_offset_ * g));
       }
     }
@@ -108,25 +103,22 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     for (int g = 0; g < this->group_; g++) {
       // Gradient w.r.t. bias.
       if (this->bias_term_ && this->param_propagate_down_[1]) {
-        Dtype alpha = 1.0;
-        Dtype beta = 1.0;
         CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
-            reinterpret_cast<void *>(&alpha),
+            cudnn::dataType<Dtype>::one,
             top_descs_[i], top_diff + top_offset_ * g,
-            reinterpret_cast<void *>(&beta),
+            cudnn::dataType<Dtype>::one,
             bias_desc_, bias_diff + bias_offset_ * g));
       }
 
       // Gradient w.r.t. weights.
       if (this->param_propagate_down_[0]) {
         const Dtype* bottom_data = bottom[i]->gpu_data();
-        Dtype alpha = 1.0;
-        Dtype beta = 1.0;
         CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
-            reinterpret_cast<void *>(&alpha),
+            cudnn::dataType<Dtype>::one,
             bottom_descs_[i], bottom_data + bottom_offset_ * g,
             top_descs_[i], top_diff + top_offset_ * g,
-            conv_descs_[i], reinterpret_cast<void *>(&beta),
+            conv_descs_[i],
+            cudnn::dataType<Dtype>::one,
             filter_desc_, weight_diff + weight_offset_ * g));
       }
 
@@ -136,13 +128,12 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           weight = this->blobs_[0]->gpu_data();
         }
         Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
-        Dtype alpha = 1.0;
-        Dtype beta = 0.0;
         CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
-            reinterpret_cast<void *>(&alpha),
+            cudnn::dataType<Dtype>::one,
             filter_desc_, weight + weight_offset_ * g,
-            top_descs_[i], top_diff + top_offset_ * g,
-            conv_descs_[i], reinterpret_cast<void *>(&beta),
+            top_descs_[i], top_diff + top_offset_ * g,
+            conv_descs_[i],
+            cudnn::dataType<Dtype>::zero,
             bottom_descs_[i], bottom_diff + bottom_offset_ * g));
       }
     }
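Note the beta values above: since cuDNN computes out = alpha * gradient + beta * out, the backward-bias and backward-filter calls pass cudnn::dataType<Dtype>::one for beta so each pass through the loop over top blobs accumulates into the shared bias_diff and weight_diff buffers, while backward-data passes zero and simply overwrites bottom_diff.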
16 changes: 4 additions & 12 deletions src/caffe/layers/cudnn_pooling_layer.cu
@@ -14,14 +14,10 @@ void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       top_desc_, top_data));
 }
 
@@ -35,15 +31,11 @@ void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_data = top[0]->gpu_data();
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       top_desc_, top_data, top_desc_, top_diff,
       bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       bottom_desc_, bottom_diff));
 }
 
16 changes: 4 additions & 12 deletions src/caffe/layers/cudnn_relu_layer.cu
@@ -17,15 +17,11 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_RELU,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
 }
 
@@ -46,16 +42,12 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_RELU,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
 }
 
16 changes: 4 additions & 12 deletions src/caffe/layers/cudnn_sigmoid_layer.cu
@@ -12,15 +12,11 @@ void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_SIGMOID,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
 }
 
@@ -36,16 +32,12 @@ void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_SIGMOID,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
 }
 
15 changes: 4 additions & 11 deletions src/caffe/layers/cudnn_softmax_layer.cu
@@ -16,15 +16,11 @@ void CuDNNSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE,
       CUDNN_SOFTMAX_MODE_CHANNEL,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       top_desc_, top_data));
 }
 
@@ -37,14 +33,11 @@ void CuDNNSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
 
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE,
       CUDNN_SOFTMAX_MODE_CHANNEL,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       top_desc_, top_data, top_desc_, top_diff,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       bottom_desc_, bottom_diff));
 }
 }
15 changes: 4 additions & 11 deletions src/caffe/layers/cudnn_tanh_layer.cu
@@ -12,15 +12,11 @@ void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
-
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_TANH,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
 }
 
@@ -37,15 +33,12 @@ void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
 
-  Dtype alpha = 1.0;
-  Dtype beta = 0.0;
-
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_TANH,
-      reinterpret_cast<void *>(&alpha),
+      cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
-      reinterpret_cast<void *>(&beta),
+      cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
 }
 
23 changes: 23 additions & 0 deletions src/caffe/util/cudnn.cpp
@@ -0,0 +1,23 @@
+#ifdef USE_CUDNN
+#include "caffe/util/cudnn.hpp"
+
+namespace caffe {
+namespace cudnn {
+
+float dataType<float>::oneval = 1.0;
+float dataType<float>::zeroval = 0.0;
+const void* dataType<float>::one =
+    static_cast<void *>(&dataType<float>::oneval);
+const void* dataType<float>::zero =
+    static_cast<void *>(&dataType<float>::zeroval);
+
+double dataType<double>::oneval = 1.0;
+double dataType<double>::zeroval = 0.0;
+const void* dataType<double>::one =
+    static_cast<void *>(&dataType<double>::oneval);
+const void* dataType<double>::zero =
+    static_cast<void *>(&dataType<double>::zeroval);
+
+} // namespace cudnn
+} // namespace caffe
+#endif
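With these definitions in place, a call site passes the shared pointers directly. Below is a minimal sketch of a wrapper in the style of the calls above; relu_forward is a hypothetical name, and it assumes caffe/util/cudnn.hpp plus the cuDNN v2-era cudnnActivationForward signature this code targets.

#include "caffe/util/cudnn.hpp"  // CUDNN_CHECK and cudnn::dataType<Dtype>

// Hypothetical helper, not part of the commit: forwards one tensor
// through ReLU, overwriting top (beta = zero) rather than accumulating.
template <typename Dtype>
void relu_forward(cudnnHandle_t handle,
                  cudnnTensorDescriptor_t bottom_desc, const Dtype* bottom,
                  cudnnTensorDescriptor_t top_desc, Dtype* top) {
  CUDNN_CHECK(cudnnActivationForward(handle,
      CUDNN_ACTIVATION_RELU,
      caffe::cudnn::dataType<Dtype>::one,   // alpha: use the result as-is
      bottom_desc, bottom,
      caffe::cudnn::dataType<Dtype>::zero,  // beta: overwrite, don't accumulate
      top_desc, top));
}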
