Skip to content

Commit

Permalink
Add cuDNN v5 support, drop cuDNN v3 support
Browse files Browse the repository at this point in the history
cuDNN v4 is still supported.
  • Loading branch information
flx42 committed May 16, 2016
1 parent bb0c1a5 commit b43c8e4
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 19 deletions.
1 change: 1 addition & 0 deletions include/caffe/layers/cudnn_relu_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif

Expand Down
1 change: 1 addition & 0 deletions include/caffe/layers/cudnn_sigmoid_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif

Expand Down
1 change: 1 addition & 0 deletions include/caffe/layers/cudnn_tanh_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif

Expand Down
24 changes: 21 additions & 3 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,13 @@ template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
CUDNN_TENSOR_NCHW, n, c, h, w));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
CUDNN_TENSOR_NCHW, n, c, h, w));
#endif
}

template <typename Dtype>
Expand Down Expand Up @@ -123,8 +128,21 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
#endif
}

template <typename Dtype>
inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
cudnnActivationMode_t mode) {
CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
CUDNN_PROPAGATE_NAN, Dtype(0)));
}

} // namespace cudnn
Expand Down
12 changes: 2 additions & 10 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,11 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
#if CUDNN_VERSION_MIN(4, 0, 0)
CUDNN_CHECK(cudnnAddTensor(handle_[g],
cudnn::dataType<Dtype>::one,
bias_desc_, bias_data + bias_offset_ * g,
cudnn::dataType<Dtype>::one,
top_descs_[i], top_data + top_offset_ * g));
#else
CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_data + bias_offset_ * g,
cudnn::dataType<Dtype>::one,
top_descs_[i], top_data + top_offset_ * g));
#endif
}
}

Expand Down Expand Up @@ -82,7 +74,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
CUDNN_CHECK(cudnnConvolutionBackwardFilter(
handle_[1*this->group_ + g],
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
Expand All @@ -100,7 +92,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
weight = this->blobs_[0]->gpu_data();
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
CUDNN_CHECK(cudnnConvolutionBackwardData(
handle_[2*this->group_ + g],
cudnn::dataType<Dtype>::one,
filter_desc_, weight + this->weight_offset_ * g,
Expand Down
1 change: 1 addition & 0 deletions src/caffe/layers/cudnn_relu_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createActivationDescriptor<Dtype>(&activ_desc_, CUDNN_ACTIVATION_RELU);
handles_setup_ = true;
}

Expand Down
23 changes: 21 additions & 2 deletions src/caffe/layers/cudnn_relu_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,

const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_RELU,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}

template <typename Dtype>
Expand All @@ -40,13 +49,23 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_RELU,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/layers/cudnn_sigmoid_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createActivationDescriptor<Dtype>(&activ_desc_,
CUDNN_ACTIVATION_SIGMOID);
handles_setup_ = true;
}

Expand Down
23 changes: 21 additions & 2 deletions src/caffe/layers/cudnn_sigmoid_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_SIGMOID,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}

template <typename Dtype>
Expand All @@ -30,13 +39,23 @@ void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_SIGMOID,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer);
Expand Down
1 change: 1 addition & 0 deletions src/caffe/layers/cudnn_tanh_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createActivationDescriptor<Dtype>(&activ_desc_, CUDNN_ACTIVATION_TANH);
handles_setup_ = true;
}

Expand Down
23 changes: 21 additions & 2 deletions src/caffe/layers/cudnn_tanh_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(this->handle_,
CUDNN_ACTIVATION_TANH,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}

template <typename Dtype>
Expand All @@ -31,13 +40,23 @@ void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();

#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
CUDNN_ACTIVATION_TANH,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(this->handle_,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);
Expand Down

2 comments on commit b43c8e4

@lld533
Copy link

@lld533 lld533 commented on b43c8e4 Sep 3, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Line ~40 of cudnn_{tanh, sigmoid, relu}layer.cpp, to prevent possible memory leak, it's better to destroy activation descriptor by calling "cudnnDestroyActivationDescriptor(this->activ_desc)"; before destroy the handle_. Much thanks.

@flx42
Copy link
Contributor Author

@flx42 flx42 commented on b43c8e4 Sep 3, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lld533: you should submit a PR then.

Please sign in to comment.