From 947f3ca06412de328bced7aeea597343920669f4 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 14:27:27 +0800
Subject: [PATCH 1/9] add ConvolutionDepthwise layer

---
 include/caffe/layers/conv_dw_layer.hpp |  48 ++++
 src/caffe/layers/conv_dw_layer.cpp     | 315 +++++++++++++++++++++++++
 src/caffe/layers/conv_dw_layer.cu      | 212 +++++++++++++++++
 3 files changed, 575 insertions(+)
 create mode 100644 include/caffe/layers/conv_dw_layer.hpp
 create mode 100644 src/caffe/layers/conv_dw_layer.cpp
 create mode 100644 src/caffe/layers/conv_dw_layer.cu

diff --git a/include/caffe/layers/conv_dw_layer.hpp b/include/caffe/layers/conv_dw_layer.hpp
new file mode 100644
index 00000000000..4ed51d6ed91
--- /dev/null
+++ b/include/caffe/layers/conv_dw_layer.hpp
@@ -0,0 +1,48 @@
+#ifndef CAFFE_CONV_DW_LAYER_HPP_
+#define CAFFE_CONV_DW_LAYER_HPP_
+
+#include <vector>
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class ConvolutionDepthwiseLayer : public Layer<Dtype> {
+ public:
+  explicit ConvolutionDepthwiseLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+  virtual inline const char* type() const { return "ConvolutionDepthwise"; }
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  unsigned int kernel_h_;
+  unsigned int kernel_w_;
+  unsigned int stride_h_;
+  unsigned int stride_w_;
+  unsigned int pad_h_;
+  unsigned int pad_w_;
+  unsigned int dilation_h_;
+  unsigned int dilation_w_;
+  Blob<Dtype> weight_buffer_;
+  Blob<Dtype> weight_multiplier_;
+  Blob<Dtype> bias_buffer_;
+  Blob<Dtype> bias_multiplier_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_CONV_DW_LAYER_HPP_
diff --git a/src/caffe/layers/conv_dw_layer.cpp b/src/caffe/layers/conv_dw_layer.cpp
new file mode 100644
index 00000000000..1232f1693a6
--- /dev/null
+++ b/src/caffe/layers/conv_dw_layer.cpp
@@ -0,0 +1,315 @@
+#include <algorithm>
+#include <vector>
+#include "caffe/filler.hpp"
+#include "caffe/layers/conv_dw_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  ConvolutionParameter conv_param = this->layer_param_.convolution_param();
+  if (conv_param.has_kernel_h() && conv_param.has_kernel_w()) {
+    kernel_h_ = conv_param.kernel_h();
+    kernel_w_ = conv_param.kernel_w();
+  } else {
+    if (conv_param.kernel_size_size() == 1)
+    {
+      kernel_h_ = conv_param.kernel_size(0);
+      kernel_w_ = conv_param.kernel_size(0);
+    }
+    else
+    {
+      kernel_h_ = conv_param.kernel_size(0);
+      kernel_w_ = conv_param.kernel_size(1);
+    }
+  }
+  if (conv_param.has_stride_h() && conv_param.has_stride_w()) {
+    stride_h_ = conv_param.stride_h();
+    stride_w_ = conv_param.stride_w();
+  } else {
+    if (conv_param.stride_size() == 1)
+    {
+      stride_h_ = conv_param.stride(0);
+      stride_w_ = conv_param.stride(0);
+    }
+    else
+    {
+      stride_h_ = conv_param.stride(0);
+      stride_w_ = conv_param.stride(1);
+    }
+  }
+  if (conv_param.has_pad_h() && conv_param.has_pad_w()) {
+    pad_h_ = conv_param.pad_h();
+    pad_w_ = conv_param.pad_w();
+  } else {
+    if (conv_param.pad_size() == 1)
+    {
+      pad_h_ = conv_param.pad(0);
+      pad_w_ = conv_param.pad(0);
+    }
+    else
+    {
+      pad_h_ = conv_param.pad(0);
+      pad_w_ = conv_param.pad(1);
+    }
+  }
+  if (conv_param.dilation_size() > 0)
+  {
+    if (conv_param.dilation_size() == 1)
+    {
+      dilation_h_ = conv_param.dilation(0);
+      dilation_w_ = conv_param.dilation(0);
+    }
+    else
+    {
+      dilation_h_ = conv_param.dilation(0);
+      dilation_w_ = conv_param.dilation(1);
+    }
+  }
+  else
+  {
+    dilation_h_ = 1;
+    dilation_w_ = 1;
+  }
+  vector<int> weight_shape(4);
+  weight_shape[0] = bottom[0]->channels();
+  weight_shape[1] = 1;
+  weight_shape[2] = kernel_h_;
+  weight_shape[3] = kernel_w_;
+  vector<int> bias_shape;
+  if (conv_param.bias_term())
+  {
+    bias_shape.push_back(bottom[0]->channels());
+  }
+  if (this->blobs_.size() == 0) {
+    if (conv_param.bias_term()) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(conv_param.weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    if (conv_param.bias_term()) {
+      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(conv_param.bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
+  }
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  vector<int> top_shape;
+  top_shape.push_back(bottom[0]->num());
+  top_shape.push_back(bottom[0]->channels());
+  top_shape.push_back((bottom[0]->height() + 2 * pad_h_ - (dilation_h_ * (kernel_h_ - 1) + 1)) / stride_h_ + 1);
+  top_shape.push_back((bottom[0]->width() + 2 * pad_w_ - (dilation_w_ * (kernel_w_ - 1) + 1)) / stride_w_ + 1);
+  top[0]->Reshape(top_shape);
+  vector<int> weight_buffer_shape;
+  weight_buffer_shape.push_back(bottom[0]->channels());
+  weight_buffer_shape.push_back(kernel_h_);
+  weight_buffer_shape.push_back(kernel_w_);
+  weight_buffer_shape.push_back(bottom[0]->num());
+  weight_buffer_shape.push_back(top[0]->height());
+  weight_buffer_shape.push_back(top[0]->width());
+  weight_buffer_.Reshape(weight_buffer_shape);
+  vector<int> weight_multiplier_shape;
+  weight_multiplier_shape.push_back(bottom[0]->num());
+  weight_multiplier_shape.push_back(top[0]->height());
+  weight_multiplier_shape.push_back(top[0]->width());
+  weight_multiplier_.Reshape(weight_multiplier_shape);
+  caffe_gpu_set(weight_multiplier_.count(), Dtype(1), weight_multiplier_.mutable_gpu_data());
+  if (this->layer_param_.convolution_param().bias_term())
+  {
+    vector<int> bias_buffer_shape;
+    bias_buffer_shape.push_back(bottom[0]->channels());
+    bias_buffer_shape.push_back(bottom[0]->num());
+    bias_buffer_shape.push_back(top[0]->height());
+    bias_buffer_shape.push_back(top[0]->width());
+    bias_buffer_.Reshape(bias_buffer_shape);
+    vector<int> bias_multiplier_shape;
+    bias_multiplier_shape.push_back(bottom[0]->num());
+    bias_multiplier_shape.push_back(top[0]->height());
+    bias_multiplier_shape.push_back(top[0]->width());
+    bias_multiplier_.Reshape(bias_multiplier_shape);
+    caffe_gpu_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_gpu_data());
+  }
+}
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top)
+{
+  const int num = top[0]->num();
+  const int channels = top[0]->channels();
+  const int top_height = top[0]->height();
+  const int top_width = top[0]->width();
+  const int bottom_height = bottom[0]->height();
+  const int bottom_width = bottom[0]->width();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* weight_data_base = this->blobs_[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  for (int n = 0; n < num; ++n)
+  {
+    for (int c = 0; c < channels; ++c)
+    {
+      for (int h = 0; h < top_height; ++h)
+      {
+        for (int w = 0; w < top_width; ++w)
+        {
+          const Dtype* weight_data = weight_data_base + c * kernel_h_ * kernel_w_;
+          Dtype value = 0;
+          for (int kh = 0; kh < kernel_h_; ++kh)
+          {
+            for (int kw = 0; kw < kernel_w_; ++kw)
+            {
+              int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
+              int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
+              if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
+              {
+                int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+                value += (*weight_data) * bottom_data[offset];
+              }
+              ++weight_data;
+            }
+          }
+          *top_data++ = value;
+        }
+      }
+    }
+  }
+  if (this->layer_param_.convolution_param().bias_term())
+  {
+    top_data = top[0]->mutable_cpu_data();
+    for (int n = 0; n < num; ++n)
+    {
+      const Dtype* bias_data = this->blobs_[1]->cpu_data();
+      for (int c = 0; c < channels; ++c)
+      {
+        for (int h = 0; h < top_height; ++h)
+        {
+          for (int w = 0; w < top_width; ++w)
+          {
+            *top_data += *bias_data;
+            ++top_data;
+          }
+        }
+        ++bias_data;
+      }
+    }
+  }
+}
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
+{
+  const int num = top[0]->num();
+  const int channels = top[0]->channels();
+  const int top_height = top[0]->height();
+  const int top_width = top[0]->width();
+  const int bottom_height = bottom[0]->height();
+  const int bottom_width = bottom[0]->width();
+  caffe_set(bottom[0]->count(), Dtype(0), bottom[0]->mutable_cpu_diff());
+  if (this->layer_param_.convolution_param().bias_term() && this->param_propagate_down_[1])
+  {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    for (int n = 0; n < num; ++n)
+    {
+      Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+      for (int c = 0; c < channels; ++c)
+      {
+        for (int h = 0; h < top_height; ++h)
+        {
+          for (int w = 0; w < top_width; ++w)
+          {
+            *bias_diff += *top_diff;
+            ++top_diff;
+          }
+        }
+        ++bias_diff;
+      }
+    }
+  }
+  if (this->param_propagate_down_[0])
+  {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    Dtype* weight_diff_base = this->blobs_[0]->mutable_cpu_diff();
+    for (int n = 0; n < num; ++n)
+    {
+      for (int c = 0; c < channels; ++c)
+      {
+        for (int h = 0; h < top_height; ++h)
+        {
+          for (int w = 0; w < top_width; ++w)
+          {
+            Dtype* weight_diff = weight_diff_base + c * kernel_h_ * kernel_w_;
+            for (int kh = 0; kh < kernel_h_; ++kh)
+            {
+              for (int kw = 0; kw < kernel_w_; ++kw)
+              {
+                int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
+                int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
+                if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
+                {
+                  int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+                  *weight_diff += bottom_data[offset] * (*top_diff);
+                }
+                ++weight_diff;
+              }
+            }
+            ++top_diff;
+          }
+        }
+      }
+    }
+  }
+  if (propagate_down[0])
+  {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* weight_data_base = this->blobs_[0]->cpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    for (int n = 0; n < num; ++n)
+    {
+      for (int c = 0; c < channels; ++c)
+      {
+        for (int h = 0; h < top_height; ++h)
+        {
+          for (int w = 0; w < top_width; ++w)
+          {
+            const Dtype* weight_data = weight_data_base + c * kernel_h_ * kernel_w_;
+            for (int kh = 0; kh < kernel_h_; ++kh)
+            {
+              for (int kw = 0; kw < kernel_w_; ++kw)
+              {
+                int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
+                int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
+                if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
+                {
+                  int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+                  bottom_diff[offset] += (*weight_data) * (*top_diff);
+                }
+                ++weight_data;
+              }
+            }
+            ++top_diff;
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(ConvolutionDepthwiseLayer);
+#endif
+
+INSTANTIATE_CLASS(ConvolutionDepthwiseLayer);
+REGISTER_LAYER_CLASS(ConvolutionDepthwise);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
new file mode 100644
index 00000000000..ed617495070
--- /dev/null
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -0,0 +1,212 @@
+#include <vector>
+#include "caffe/layers/conv_dw_layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void ConvolutionDepthwiseWeightForward(const int nthreads,
+    const Dtype* const bottom_data, const Dtype* const weight_data, const int num, const int channels,
+    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
+    Dtype* const top_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / channels / top_height / top_width;
+    const int c = (index / top_height / top_width) % channels;
+    const int h = (index / top_width) % top_height;
+    const int w = index % top_width;
+    const Dtype* weight = weight_data + c * kernel_h * kernel_w;
+    Dtype value = 0;
+    for (int kh = 0; kh < kernel_h; ++kh)
+    {
+      for (int kw = 0; kw < kernel_w; ++kw)
+      {
+        const int h_in = -pad_h + h * stride_h + kh * dilation_h;
+        const int w_in = -pad_w + w * stride_w + kw * dilation_w;
+        if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
+        {
+          const int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+          value += (*weight) * bottom_data[offset];
+        }
+        ++weight;
+      }
+    }
+    top_data[index] = value;
+  }
+}
+
+template <typename Dtype>
+__global__ void ConvolutionDepthwiseBiasForward(const int nthreads,
+    const Dtype* const bias_data, const int num, const int channels,
+    const int top_height, const int top_width, Dtype* const top_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int c = (index / top_height / top_width) % channels;
+    top_data[index] += bias_data[c];
+  }
+}
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const Dtype* weight_data = this->blobs_[0]->gpu_data();
+  const int count = top[0]->count();
+  const int num = top[0]->num();
+  const int channels = top[0]->channels();
+  const int top_height = top[0]->height();
+  const int top_width = top[0]->width();
+  const int bottom_height = bottom[0]->height();
+  const int bottom_width = bottom[0]->width();
+  ConvolutionDepthwiseWeightForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, weight_data, num, channels,
+      top_height, top_width, bottom_height, bottom_width,
+      kernel_h_, kernel_w_, stride_h_, stride_w_,
+      pad_h_, pad_w_, dilation_h_, dilation_w_, top_data);
+  if (this->layer_param_.convolution_param().bias_term())
+  {
+    const Dtype* bias_data = this->blobs_[1]->gpu_data();
+    ConvolutionDepthwiseBiasForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, bias_data, num, channels,
+        top_height, top_width, top_data);
+  }
+}
+
+template <typename Dtype>
+__global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
+    const Dtype* const top_diff, const Dtype* const bottom_data, const int num, const int channels,
+    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
+    Dtype* const buffer_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int h = (index / top_width) % top_height;
+    const int w = index % top_width;
+    const int kh = (index / kernel_w / num / top_height / top_width) % kernel_h;
+    const int kw = (index / num / top_height / top_width) % kernel_w;
+    const int h_in = -pad_h + h * stride_h + kh * dilation_h;
+    const int w_in = -pad_w + w * stride_w + kw * dilation_w;
+    if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
+    {
+      const int c = index / kernel_h / kernel_w / num / top_height / top_width;
+      const int n = (index / top_height / top_width) % num;
+      const int top_offset = ((n * channels + c) * top_height + h) * top_width + w;
+      const int bottom_offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+      buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset];
+    }
+    else
+    {
+      buffer_data[index] = 0;
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
+    const Dtype* const top_diff, const Dtype* const weight_data, const int num, const int channels,
+    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
+    Dtype* const bottom_diff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / channels / bottom_height / bottom_width;
+    const int c = (index / bottom_height / bottom_width) % channels;
+    const int h = (index / bottom_width) % bottom_height;
+    const int w = index % bottom_width;
+    const Dtype* weight = weight_data + c * kernel_h * kernel_w;
+    Dtype value = 0;
+    for (int kh = 0; kh < kernel_h; ++kh)
+    {
+      for (int kw = 0; kw < kernel_w; ++kw)
+      {
+        const int h_out_s = h + pad_h - kh * dilation_h;
+        const int w_out_s = w + pad_w - kw * dilation_w;
+        if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0))
+        {
+          const int h_out = h_out_s / stride_h;
+          const int w_out = w_out_s / stride_w;
+          if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width))
+          {
+            const int offset = ((n * channels + c) * top_height + h_out) * top_width + w_out;
+            value += (*weight) * top_diff[offset];
+          }
+        }
+        ++weight;
+      }
+    }
+    bottom_diff[index] += value;
+  }
+}
+
+template <typename Dtype>
+__global__ void ConvolutionDepthwiseBiasBackward(const int nthreads,
+    const Dtype* const top_diff, const int num, const int channels,
+    const int top_height, const int top_width, Dtype* const buffer_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int c = index / num / top_height / top_width;
+    const int n = (index / top_height / top_width) % num;
+    const int h = (index / top_width) % top_height;
+    const int w = index % top_width;
+    const int offset = ((n * channels + c) * top_height + h) * top_width + w;
+    buffer_data[index] = top_diff[offset];
+  }
+}
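+
+// Gradient reduction strategy: ConvolutionDepthwiseWeightBackward and
+// ConvolutionDepthwiseBiasBackward above do not accumulate anything
+// themselves; they only scatter per-position products into weight_buffer_
+// and bias_buffer_, whose trailing dimensions are (num, top_height,
+// top_width). Backward_gpu below then sums each buffer row of length
+// num * top_height * top_width with a single caffe_gpu_gemv against the
+// all-ones multiplier blobs filled in Reshape, producing the final weight
+// and bias gradients.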
+
+template <typename Dtype>
+void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const int bottom_count = bottom[0]->count();
+  const int num = top[0]->num();
+  const int channels = top[0]->channels();
+  const int top_height = top[0]->height();
+  const int top_width = top[0]->width();
+  const int bottom_height = bottom[0]->height();
+  const int bottom_width = bottom[0]->width();
+  const int length = num * top_height * top_width;
+  caffe_gpu_set(bottom_count, Dtype(0), bottom[0]->mutable_gpu_diff());
+  if (this->layer_param_.convolution_param().bias_term() && this->param_propagate_down_[1])
+  {
+    const int bias_buffer_count = bias_buffer_.count();
+    Dtype* bias_buffer_mutable_data = bias_buffer_.mutable_gpu_data();
+    ConvolutionDepthwiseBiasBackward<Dtype><<<CAFFE_GET_BLOCKS(bias_buffer_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        bias_buffer_count, top_diff, num, channels,
+        top_height, top_width, bias_buffer_mutable_data);
+    const int bias_count = this->blobs_[1]->count();
+    const Dtype* bias_buffer_data = bias_buffer_.gpu_data();
+    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    const Dtype* bias_multiplier_data = bias_multiplier_.gpu_data();
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, bias_count, length, Dtype(1), bias_buffer_data, bias_multiplier_data, Dtype(1), bias_diff);
+  }
+  if (this->param_propagate_down_[0])
+  {
+    const int weight_buffer_count = weight_buffer_.count();
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    Dtype* weight_buffer_mutable_data = weight_buffer_.mutable_gpu_data();
+    ConvolutionDepthwiseWeightBackward<Dtype><<<CAFFE_GET_BLOCKS(weight_buffer_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        weight_buffer_count, top_diff, bottom_data, num, channels,
+        top_height, top_width, bottom_height, bottom_width,
+        kernel_h_, kernel_w_, stride_h_, stride_w_,
+        pad_h_, pad_w_, dilation_h_, dilation_w_, weight_buffer_mutable_data);
+    const int weight_count = this->blobs_[0]->count();
+    const Dtype* weight_buffer_data = weight_buffer_.gpu_data();
+    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    const Dtype* weight_multiplier_data = weight_multiplier_.gpu_data();
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, weight_count, length, Dtype(1), weight_buffer_data, weight_multiplier_data, Dtype(1), weight_diff);
+  }
+  if (propagate_down[0])
+  {
+    const Dtype* weight_data = this->blobs_[0]->gpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    ConvolutionDepthwiseBottomBackward<Dtype><<<CAFFE_GET_BLOCKS(bottom_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        bottom_count, top_diff, weight_data, num, channels,
+        top_height, top_width, bottom_height, bottom_width,
+        kernel_h_, kernel_w_, stride_h_, stride_w_,
+        pad_h_, pad_w_, dilation_h_, dilation_w_, bottom_diff);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(ConvolutionDepthwiseLayer);
+
+}  // namespace caffe

From 978b11feaedc75a9ed52147f316815deb6f3eb30 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 14:44:57 +0800
Subject: [PATCH 2/9] add ConvolutionDepthwise layer

---
 src/caffe/layers/conv_dw_layer.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/caffe/layers/conv_dw_layer.cpp b/src/caffe/layers/conv_dw_layer.cpp
index 1232f1693a6..39dda270c38 100644
--- a/src/caffe/layers/conv_dw_layer.cpp
+++ b/src/caffe/layers/conv_dw_layer.cpp
@@ -1,6 +1,7 @@
 #include <algorithm>
 #include <vector>
 #include "caffe/filler.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/layers/conv_dw_layer.hpp"
 
 namespace caffe {
@@ -122,7 +123,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& botto
   weight_multiplier_shape.push_back(top[0]->height());
   weight_multiplier_shape.push_back(top[0]->width());
   weight_multiplier_.Reshape(weight_multiplier_shape);
-  caffe_gpu_set(weight_multiplier_.count(), Dtype(1), weight_multiplier_.mutable_gpu_data());
+  caffe_set(weight_multiplier_.count(), Dtype(1), weight_multiplier_.mutable_gpu_data());
   if (this->layer_param_.convolution_param().bias_term())
   {
     vector<int> bias_buffer_shape;
@@ -136,7 +137,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& botto
     bias_multiplier_shape.push_back(top[0]->height());
     bias_multiplier_shape.push_back(top[0]->width());
     bias_multiplier_.Reshape(bias_multiplier_shape);
-    caffe_gpu_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_gpu_data());
+    caffe_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_gpu_data());
   }
 }

From 15191abdf6e0f8b0e217d0843764cd1bfe1b6ce7 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 15:47:19 +0800
Subject: [PATCH 3/9] satisfy the code format of caffe

---
 include/caffe/layers/conv_dw_layer.hpp |   1 +
 src/caffe/layers/conv_dw_layer.cpp     | 196 ++++++++++---------------
 src/caffe/layers/conv_dw_layer.cu      | 113 +++++++-------
 3 files changed, 139 insertions(+), 171 deletions(-)

diff --git a/include/caffe/layers/conv_dw_layer.hpp b/include/caffe/layers/conv_dw_layer.hpp
index 4ed51d6ed91..8b5133ada24 100644
--- a/include/caffe/layers/conv_dw_layer.hpp
+++ b/include/caffe/layers/conv_dw_layer.hpp
@@ -20,6 +20,7 @@ class ConvolutionDepthwiseLayer : public Layer<Dtype> {
   virtual inline int ExactNumBottomBlobs() const { return 1; }
   virtual inline int ExactNumTopBlobs() const { return 1; }
   virtual inline const char* type() const { return "ConvolutionDepthwise"; }
+
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
diff --git a/src/caffe/layers/conv_dw_layer.cpp b/src/caffe/layers/conv_dw_layer.cpp
index 39dda270c38..2cc742f8de5 100644
--- a/src/caffe/layers/conv_dw_layer.cpp
+++ b/src/caffe/layers/conv_dw_layer.cpp
@@ -7,20 +7,17 @@
 namespace caffe {
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
-    const vector<Blob<Dtype>*>& top) {
+void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   ConvolutionParameter conv_param = this->layer_param_.convolution_param();
   if (conv_param.has_kernel_h() && conv_param.has_kernel_w()) {
     kernel_h_ = conv_param.kernel_h();
     kernel_w_ = conv_param.kernel_w();
   } else {
-    if (conv_param.kernel_size_size() == 1)
-    {
+    if (conv_param.kernel_size_size() == 1) {
       kernel_h_ = conv_param.kernel_size(0);
       kernel_w_ = conv_param.kernel_size(0);
-    }
-    else
-    {
+    } else {
       kernel_h_ = conv_param.kernel_size(0);
       kernel_w_ = conv_param.kernel_size(1);
     }
@@ -29,13 +26,10 @@ void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bo
     stride_h_ = conv_param.stride_h();
     stride_w_ = conv_param.stride_w();
   } else {
-    if (conv_param.stride_size() == 1)
-    {
+    if (conv_param.stride_size() == 1) {
       stride_h_ = conv_param.stride(0);
       stride_w_ = conv_param.stride(0);
-    }
-    else
-    {
+    } else {
       stride_h_ = conv_param.stride(0);
       stride_w_ = conv_param.stride(1);
     }
@@ -44,32 +38,23 @@ void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bo
     pad_h_ = conv_param.pad_h();
     pad_w_ = conv_param.pad_w();
   } else {
-    if (conv_param.pad_size() == 1)
-    {
+    if (conv_param.pad_size() == 1) {
       pad_h_ = conv_param.pad(0);
       pad_w_ = conv_param.pad(0);
-    }
-    else
-    {
+    } else {
       pad_h_ = conv_param.pad(0);
       pad_w_ = conv_param.pad(1);
     }
   }
-  if (conv_param.dilation_size() > 0)
-  {
-    if (conv_param.dilation_size() == 1)
-    {
+  if (conv_param.dilation_size() > 0) {
+    if (conv_param.dilation_size() == 1) {
       dilation_h_ = conv_param.dilation(0);
       dilation_w_ = conv_param.dilation(0);
-    }
-    else
-    {
+    } else {
       dilation_h_ = conv_param.dilation(0);
       dilation_w_ = conv_param.dilation(1);
     }
-  }
-  else
-  {
+  } else {
     dilation_h_ = 1;
     dilation_w_ = 1;
   }
@@ -79,8 +64,7 @@ void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bo
   weight_shape[2] = kernel_h_;
   weight_shape[3] = kernel_w_;
   vector<int> bias_shape;
-  if (conv_param.bias_term())
-  {
+  if (conv_param.bias_term()) {
     bias_shape.push_back(bottom[0]->channels());
   }
   if (this->blobs_.size() == 0) {
@@ -90,11 +74,13 @@ void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bo
       this->blobs_.resize(1);
     }
     this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
-    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(conv_param.weight_filler()));
+    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+        conv_param.weight_filler()));
     weight_filler->Fill(this->blobs_[0].get());
     if (conv_param.bias_term()) {
       this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
-      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(conv_param.bias_filler()));
+      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+          conv_param.bias_filler()));
       bias_filler->Fill(this->blobs_[1].get());
     }
   }
@@ -102,13 +88,15 @@ void ConvolutionDepthwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bo
 }
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
-    const vector<Blob<Dtype>*>& top) {
+void ConvolutionDepthwiseLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   vector<int> top_shape;
   top_shape.push_back(bottom[0]->num());
   top_shape.push_back(bottom[0]->channels());
-  top_shape.push_back((bottom[0]->height() + 2 * pad_h_ - (dilation_h_ * (kernel_h_ - 1) + 1)) / stride_h_ + 1);
-  top_shape.push_back((bottom[0]->width() + 2 * pad_w_ - (dilation_w_ * (kernel_w_ - 1) + 1)) / stride_w_ + 1);
+  top_shape.push_back((bottom[0]->height() + 2 * pad_h_
+      - (dilation_h_ * (kernel_h_ - 1) + 1)) / stride_h_ + 1);
+  top_shape.push_back((bottom[0]->width() + 2 * pad_w_
+      - (dilation_w_ * (kernel_w_ - 1) + 1)) / stride_w_ + 1);
   top[0]->Reshape(top_shape);
   vector<int> weight_buffer_shape;
   weight_buffer_shape.push_back(bottom[0]->channels());
@@ -123,9 +111,9 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& botto
   weight_multiplier_shape.push_back(top[0]->height());
   weight_multiplier_shape.push_back(top[0]->width());
   weight_multiplier_.Reshape(weight_multiplier_shape);
-  caffe_set(weight_multiplier_.count(), Dtype(1), weight_multiplier_.mutable_gpu_data());
-  if (this->layer_param_.convolution_param().bias_term())
-  {
+  caffe_set(weight_multiplier_.count(), Dtype(1),
+      weight_multiplier_.mutable_gpu_data());
+  if (this->layer_param_.convolution_param().bias_term()) {
     vector<int> bias_buffer_shape;
     bias_buffer_shape.push_back(bottom[0]->channels());
     bias_buffer_shape.push_back(bottom[0]->num());
@@ -137,14 +125,14 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& botto
     bias_multiplier_shape.push_back(top[0]->height());
     bias_multiplier_shape.push_back(top[0]->width());
     bias_multiplier_.Reshape(bias_multiplier_shape);
-    caffe_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_gpu_data());
+    caffe_set(bias_multiplier_.count(), Dtype(1),
+        bias_multiplier_.mutable_gpu_data());
   }
 }
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    const vector<Blob<Dtype>*>& top)
-{
+void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   const int num = top[0]->num();
   const int channels = top[0]->channels();
   const int top_height = top[0]->height();
@@ -154,25 +142,21 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& b
   const Dtype* bottom_data = bottom[0]->cpu_data();
   const Dtype* weight_data_base = this->blobs_[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  for (int n = 0; n < num; ++n)
-  {
-    for (int c = 0; c < channels; ++c)
-    {
-      for (int h = 0; h < top_height; ++h)
-      {
-        for (int w = 0; w < top_width; ++w)
-        {
-          const Dtype* weight_data = weight_data_base + c * kernel_h_ * kernel_w_;
+  for (int n = 0; n < num; ++n) {
+    for (int c = 0; c < channels; ++c) {
+      for (int h = 0; h < top_height; ++h) {
+        for (int w = 0; w < top_width; ++w) {
+          const Dtype* weight_data = weight_data_base
+              + c * kernel_h_ * kernel_w_;
           Dtype value = 0;
-          for (int kh = 0; kh < kernel_h_; ++kh)
-          {
-            for (int kw = 0; kw < kernel_w_; ++kw)
-            {
+          for (int kh = 0; kh < kernel_h_; ++kh) {
+            for (int kw = 0; kw < kernel_w_; ++kw) {
               int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
               int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
-              if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
-              {
-                int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+              if ((h_in >= 0) && (h_in < bottom_height)
+                  && (w_in >= 0) && (w_in < bottom_width)) {
+                int offset = ((n * channels + c) * bottom_height + h_in)
+                    * bottom_width + w_in;
                 value += (*weight_data) * bottom_data[offset];
               }
               ++weight_data;
@@ -183,18 +167,13 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& b
       }
     }
   }
-  if (this->layer_param_.convolution_param().bias_term())
-  {
+  if (this->layer_param_.convolution_param().bias_term()) {
     top_data = top[0]->mutable_cpu_data();
-    for (int n = 0; n < num; ++n)
-    {
+    for (int n = 0; n < num; ++n) {
       const Dtype* bias_data = this->blobs_[1]->cpu_data();
-      for (int c = 0; c < channels; ++c)
-      {
-        for (int h = 0; h < top_height; ++h)
-        {
-          for (int w = 0; w < top_width; ++w)
-          {
+      for (int c = 0; c < channels; ++c) {
+        for (int h = 0; h < top_height; ++h) {
+          for (int w = 0; w < top_width; ++w) {
             *top_data += *bias_data;
             ++top_data;
           }
@@ -206,9 +185,9 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& b
 }
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
-{
+void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
   const int num = top[0]->num();
   const int channels = top[0]->channels();
   const int top_height = top[0]->height();
@@ -216,18 +195,14 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>&
   const int bottom_height = bottom[0]->height();
   const int bottom_width = bottom[0]->width();
   caffe_set(bottom[0]->count(), Dtype(0), bottom[0]->mutable_cpu_diff());
-  if (this->layer_param_.convolution_param().bias_term() && this->param_propagate_down_[1])
-  {
+  if (this->layer_param_.convolution_param().bias_term()
+      && this->param_propagate_down_[1]) {
    const Dtype* top_diff = top[0]->cpu_diff();
-    for (int n = 0; n < num; ++n)
-    {
+    for (int n = 0; n < num; ++n) {
      Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
-      for (int c = 0; c < channels; ++c)
-      {
-        for (int h = 0; h < top_height; ++h)
-        {
-          for (int w = 0; w < top_width; ++w)
-          {
+      for (int c = 0; c < channels; ++c) {
+        for (int h = 0; h < top_height; ++h) {
+          for (int w = 0; w < top_width; ++w) {
             *bias_diff += *top_diff;
             ++top_diff;
           }
@@ -236,29 +211,23 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>&
       }
     }
   }
-  if (this->param_propagate_down_[0])
-  {
+  if (this->param_propagate_down_[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* bottom_data = bottom[0]->cpu_data();
    Dtype* weight_diff_base = this->blobs_[0]->mutable_cpu_diff();
-    for (int n = 0; n < num; ++n)
-    {
-      for (int c = 0; c < channels; ++c)
-      {
-        for (int h = 0; h < top_height; ++h)
-        {
-          for (int w = 0; w < top_width; ++w)
-          {
+    for (int n = 0; n < num; ++n) {
+      for (int c = 0; c < channels; ++c) {
+        for (int h = 0; h < top_height; ++h) {
+          for (int w = 0; w < top_width; ++w) {
            Dtype* weight_diff = weight_diff_base + c * kernel_h_ * kernel_w_;
-            for (int kh = 0; kh < kernel_h_; ++kh)
-            {
-              for (int kw = 0; kw < kernel_w_; ++kw)
-              {
+            for (int kh = 0; kh < kernel_h_; ++kh) {
+              for (int kw = 0; kw < kernel_w_; ++kw) {
                 int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
                 int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
-                if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
-                {
-                  int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+                if ((h_in >= 0) && (h_in < bottom_height)
+                    && (w_in >= 0) && (w_in < bottom_width)) {
+                  int offset = ((n * channels + c) * bottom_height + h_in)
+                      * bottom_width + w_in;
                   *weight_diff += bottom_data[offset] * (*top_diff);
                 }
                 ++weight_diff;
@@ -270,29 +239,24 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>&
       }
     }
   }
-  if (propagate_down[0])
-  {
+  if (propagate_down[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* weight_data_base = this->blobs_[0]->cpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-    for (int n = 0; n < num; ++n)
-    {
-      for (int c = 0; c < channels; ++c)
-      {
-        for (int h = 0; h < top_height; ++h)
-        {
-          for (int w = 0; w < top_width; ++w)
-          {
-            const Dtype* weight_data = weight_data_base + c * kernel_h_ * kernel_w_;
+    for (int n = 0; n < num; ++n) {
+      for (int c = 0; c < channels; ++c) {
+        for (int h = 0; h < top_height; ++h) {
+          for (int w = 0; w < top_width; ++w) {
+            const Dtype* weight_data = weight_data_base
+                + c * kernel_h_ * kernel_w_;
-            for (int kh = 0; kh < kernel_h_; ++kh)
-            {
-              for (int kw = 0; kw < kernel_w_; ++kw)
-              {
+            for (int kh = 0; kh < kernel_h_; ++kh) {
+              for (int kw = 0; kw < kernel_w_; ++kw) {
                 int h_in = -pad_h_ + h * stride_h_ + kh * dilation_h_;
                 int w_in = -pad_w_ + w * stride_w_ + kw * dilation_w_;
-                if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
-                {
-                  int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+                if ((h_in >= 0) && (h_in < bottom_height)
+                    && (w_in >= 0) && (w_in < bottom_width)) {
+                  int offset = ((n * channels + c) * bottom_height + h_in)
+                      * bottom_width + w_in;
                   bottom_diff[offset] += (*weight_data) * (*top_diff);
                 }
                 ++weight_data;
diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
index ed617495070..ea0dd19a75d 100644
--- a/src/caffe/layers/conv_dw_layer.cu
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -6,11 +6,12 @@ namespace caffe {
 
 template <typename Dtype>
 __global__ void ConvolutionDepthwiseWeightForward(const int nthreads,
-    const Dtype* const bottom_data, const Dtype* const weight_data, const int num, const int channels,
-    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
-    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
-    Dtype* const top_data) {
+    const Dtype* const bottom_data, const Dtype* const weight_data,
+    const int num, const int channels, const int top_height,
+    const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w, Dtype* const top_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int n = index / channels / top_height / top_width;
     const int c = (index / top_height / top_width) % channels;
@@ -18,15 +19,14 @@ __global__ void ConvolutionDepthwiseWeightForward(const int nthreads,
     const int w = index % top_width;
     const Dtype* weight = weight_data + c * kernel_h * kernel_w;
     Dtype value = 0;
-    for (int kh = 0; kh < kernel_h; ++kh)
-    {
-      for (int kw = 0; kw < kernel_w; ++kw)
-      {
+    for (int kh = 0; kh < kernel_h; ++kh) {
+      for (int kw = 0; kw < kernel_w; ++kw) {
        const int h_in = -pad_h + h * stride_h + kh * dilation_h;
        const int w_in = -pad_w + w * stride_w + kw * dilation_w;
-        if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
-        {
-          const int offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+        if ((h_in >= 0) && (h_in < bottom_height)
+            && (w_in >= 0) && (w_in < bottom_width)) {
+          const int offset = ((n * channels + c) * bottom_height + h_in)
+              * bottom_width + w_in;
           value += (*weight) * bottom_data[offset];
         }
         ++weight;
@@ -47,8 +47,8 @@ __global__ void ConvolutionDepthwiseBiasForward(const int nthreads,
 }
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-    const vector<Blob<Dtype>*>& top) {
+void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
   const Dtype* weight_data = this->blobs_[0]->gpu_data();
@@ -59,15 +59,16 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& b
   const int top_width = top[0]->width();
   const int bottom_height = bottom[0]->height();
   const int bottom_width = bottom[0]->width();
-  ConvolutionDepthwiseWeightForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+  ConvolutionDepthwiseWeightForward<Dtype><<<CAFFE_GET_BLOCKS(count),
+      CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, weight_data, num, channels,
       top_height, top_width, bottom_height, bottom_width,
       kernel_h_, kernel_w_, stride_h_, stride_w_,
       pad_h_, pad_w_, dilation_h_, dilation_w_, top_data);
-  if (this->layer_param_.convolution_param().bias_term())
-  {
+  if (this->layer_param_.convolution_param().bias_term()) {
    const Dtype* bias_data = this->blobs_[1]->gpu_data();
-    ConvolutionDepthwiseBiasForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBiasForward<Dtype><<<CAFFE_GET_BLOCKS(count),
+        CAFFE_CUDA_NUM_THREADS>>>(
         count, bias_data, num, channels,
         top_height, top_width, top_data);
   }
@@ -75,11 +76,12 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& b
 
 template <typename Dtype>
 __global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
-    const Dtype* const top_diff, const Dtype* const bottom_data, const int num, const int channels,
-    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
-    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
-    Dtype* const buffer_data) {
+    const Dtype* const top_diff, const Dtype* const bottom_data,
+    const int num, const int channels, const int top_height,
+    const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w, Dtype* const buffer_data) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int h = (index / top_width) % top_height;
     const int w = index % top_width;
@@ -87,16 +89,14 @@ __global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
     const int kw = (index / num / top_height / top_width) % kernel_w;
     const int h_in = -pad_h + h * stride_h + kh * dilation_h;
     const int w_in = -pad_w + w * stride_w + kw * dilation_w;
-    if ((h_in >= 0) && (h_in < bottom_height) && (w_in >= 0) && (w_in < bottom_width))
-    {
+    if ((h_in >= 0) && (h_in < bottom_height)
+        && (w_in >= 0) && (w_in < bottom_width)) {
      const int c = index / kernel_h / kernel_w / num / top_height / top_width;
      const int n = (index / top_height / top_width) % num;
      const int top_offset = ((n * channels + c) * top_height + h) * top_width + w;
      const int bottom_offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
      buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset];
-    }
-    else
-    {
+    } else {
      buffer_data[index] = 0;
     }
   }
@@ -104,11 +104,12 @@ __global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
 
 template <typename Dtype>
 __global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
-    const Dtype* const top_diff, const Dtype* const weight_data, const int num, const int channels,
-    const int top_height, const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h, const int stride_w,
-    const int pad_h, const int pad_w, const int dilation_h, const int dilation_w,
-    Dtype* const bottom_diff) {
+    const Dtype* const top_diff, const Dtype* const weight_data,
+    const int num, const int channels, const int top_height,
+    const int top_width, const int bottom_height, const int bottom_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_h, const int pad_w,
+    const int dilation_h, const int dilation_w, Dtype* const bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int n = index / channels / bottom_height / bottom_width;
     const int c = (index / bottom_height / bottom_width) % channels;
@@ -116,19 +117,17 @@ __global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
     const int w = index % bottom_width;
     const Dtype* weight = weight_data + c * kernel_h * kernel_w;
     Dtype value = 0;
-    for (int kh = 0; kh < kernel_h; ++kh)
-    {
-      for (int kw = 0; kw < kernel_w; ++kw)
-      {
+    for (int kh = 0; kh < kernel_h; ++kh) {
+      for (int kw = 0; kw < kernel_w; ++kw) {
        const int h_out_s = h + pad_h - kh * dilation_h;
        const int w_out_s = w + pad_w - kw * dilation_w;
-        if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0))
-        {
+        if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) {
          const int h_out = h_out_s / stride_h;
          const int w_out = w_out_s / stride_w;
-          if ((h_out >= 0) && (h_out < top_height) && (w_out >= 0) && (w_out < top_width))
-          {
-            const int offset = ((n * channels + c) * top_height + h_out) * top_width + w_out;
+          if ((h_out >= 0) && (h_out < top_height)
+              && (w_out >= 0) && (w_out < top_width)) {
+            const int offset = ((n * channels + c) * top_height + h_out)
+                * top_width + w_out;
             value += (*weight) * top_diff[offset];
           }
         }
@@ -154,8 +153,9 @@ __global__ void ConvolutionDepthwiseBiasBackward(const int nthreads,
 }
 
 template <typename Dtype>
-void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const int bottom_count = bottom[0]->count();
   const int num = top[0]->num();
@@ -166,25 +166,27 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>&
   const int bottom_width = bottom[0]->width();
   const int length = num * top_height * top_width;
   caffe_gpu_set(bottom_count, Dtype(0), bottom[0]->mutable_gpu_diff());
-  if (this->layer_param_.convolution_param().bias_term() && this->param_propagate_down_[1])
-  {
+  if (this->layer_param_.convolution_param().bias_term()
+      && this->param_propagate_down_[1]) {
    const int bias_buffer_count = bias_buffer_.count();
    Dtype* bias_buffer_mutable_data = bias_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseBiasBackward<Dtype><<<CAFFE_GET_BLOCKS(bias_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBiasBackward<Dtype><<<CAFFE_GET_BLOCKS(bias_buffer_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
        bias_buffer_count, top_diff, num, channels,
        top_height, top_width, bias_buffer_mutable_data);
    const int bias_count = this->blobs_[1]->count();
    const Dtype* bias_buffer_data = bias_buffer_.gpu_data();
    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
    const Dtype* bias_multiplier_data = bias_multiplier_.gpu_data();
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, bias_count, length, Dtype(1), bias_buffer_data, bias_multiplier_data, Dtype(1), bias_diff);
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, bias_count, length, Dtype(1),
+        bias_buffer_data, bias_multiplier_data, Dtype(1), bias_diff);
   }
-  if (this->param_propagate_down_[0])
-  {
+  if (this->param_propagate_down_[0]) {
    const int weight_buffer_count = weight_buffer_.count();
    const Dtype* bottom_data = bottom[0]->gpu_data();
    Dtype* weight_buffer_mutable_data = weight_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseWeightBackward<Dtype><<<CAFFE_GET_BLOCKS(weight_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseWeightBackward<Dtype><<<CAFFE_GET_BLOCKS(weight_buffer_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
        weight_buffer_count, top_diff, bottom_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
        kernel_h_, kernel_w_, stride_h_, stride_w_,
@@ -193,13 +195,14 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>&
        pad_h_, pad_w_, dilation_h_, dilation_w_, weight_buffer_mutable_data);
    const int weight_count = this->blobs_[0]->count();
    const Dtype* weight_buffer_data = weight_buffer_.gpu_data();
    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
    const Dtype* weight_multiplier_data = weight_multiplier_.gpu_data();
-    caffe_gpu_gemv<Dtype>(CblasNoTrans, weight_count, length, Dtype(1), weight_buffer_data, weight_multiplier_data, Dtype(1), weight_diff);
+    caffe_gpu_gemv<Dtype>(CblasNoTrans, weight_count, length, Dtype(1),
+        weight_buffer_data, weight_multiplier_data, Dtype(1), weight_diff);
   }
-  if (propagate_down[0])
-  {
+  if (propagate_down[0]) {
    const Dtype* weight_data = this->blobs_[0]->gpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-    ConvolutionDepthwiseBottomBackward<Dtype><<<CAFFE_GET_BLOCKS(bottom_count), CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBottomBackward<Dtype><<<CAFFE_GET_BLOCKS(bottom_count),
+        CAFFE_CUDA_NUM_THREADS>>>(
        bottom_count, top_diff, weight_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
        kernel_h_, kernel_w_, stride_h_, stride_w_,

From 3d0182cdbc888f510873afc90ace473dcc6719c8 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 15:58:52 +0800
Subject: [PATCH 4/9] satisfy the code format of caffe

---
 src/caffe/layers/conv_dw_layer.cu | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
index ea0dd19a75d..9f9220f1d83 100644
--- a/src/caffe/layers/conv_dw_layer.cu
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -85,7 +85,8 @@ __global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
   CUDA_KERNEL_LOOP(index, nthreads) {
     const int h = (index / top_width) % top_height;
     const int w = index % top_width;
-    const int kh = (index / kernel_w / num / top_height / top_width) % kernel_h;
+    const int kh = (index / kernel_w / num / top_height / top_width)
+        % kernel_h;
     const int kw = (index / num / top_height / top_width) % kernel_w;
     const int h_in = -pad_h + h * stride_h + kh * dilation_h;
     const int w_in = -pad_w + w * stride_w + kw * dilation_w;
@@ -93,8 +94,10 @@ __global__ void ConvolutionDepthwiseWeightBackward(const int nthreads,
         && (w_in >= 0) && (w_in < bottom_width)) {
      const int c = index / kernel_h / kernel_w / num / top_height / top_width;
      const int n = (index / top_height / top_width) % num;
-      const int top_offset = ((n * channels + c) * top_height + h) * top_width + w;
-      const int bottom_offset = ((n * channels + c) * bottom_height + h_in) * bottom_width + w_in;
+      const int top_offset = ((n * channels + c) * top_height + h)
+          * top_width + w;
+      const int bottom_offset = ((n * channels + c) * bottom_height + h_in)
+          * bottom_width + w_in;
      buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset];
     } else {
      buffer_data[index] = 0;

From 0a6ff1ef72e3a3bea09acfbb1f381996d1e84eae Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 17:03:47 +0800
Subject: [PATCH 5/9] satisfy the code format of caffe

---
 src/caffe/layers/conv_dw_layer.cu | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
index 9f9220f1d83..dfeafb38f82 100644
--- a/src/caffe/layers/conv_dw_layer.cu
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -59,7 +59,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
   const int top_width = top[0]->width();
   const int bottom_height = bottom[0]->height();
   const int bottom_width = bottom[0]->width();
-  ConvolutionDepthwiseWeightForward<Dtype><<<CAFFE_GET_BLOCKS(count),
+  ConvolutionDepthwiseWeightForward<Dtype> <<<CAFFE_GET_BLOCKS(count),
       CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, weight_data, num, channels,
       top_height, top_width, bottom_height, bottom_width,
@@ -67,7 +67,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
       pad_h_, pad_w_, dilation_h_, dilation_w_, top_data);
   if (this->layer_param_.convolution_param().bias_term()) {
    const Dtype* bias_data = this->blobs_[1]->gpu_data();
-    ConvolutionDepthwiseBiasForward<Dtype><<<CAFFE_GET_BLOCKS(count),
+    ConvolutionDepthwiseBiasForward<Dtype> <<<CAFFE_GET_BLOCKS(count),
        CAFFE_CUDA_NUM_THREADS>>>(
        count, bias_data, num, channels,
        top_height, top_width, top_data);
@@ -173,7 +173,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
      && this->param_propagate_down_[1]) {
    const int bias_buffer_count = bias_buffer_.count();
    Dtype* bias_buffer_mutable_data = bias_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseBiasBackward<Dtype><<<CAFFE_GET_BLOCKS(bias_buffer_count),
+    ConvolutionDepthwiseBiasBackward<Dtype> <<<CAFFE_GET_BLOCKS(bias_buffer_count),
        CAFFE_CUDA_NUM_THREADS>>>(
        bias_buffer_count, top_diff, num, channels,
        top_height, top_width, bias_buffer_mutable_data);
@@ -188,7 +188,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
    const int weight_buffer_count = weight_buffer_.count();
    const Dtype* bottom_data = bottom[0]->gpu_data();
    Dtype* weight_buffer_mutable_data = weight_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseWeightBackward<Dtype><<<CAFFE_GET_BLOCKS(weight_buffer_count),
+    ConvolutionDepthwiseWeightBackward<Dtype> <<<CAFFE_GET_BLOCKS(weight_buffer_count),
        CAFFE_CUDA_NUM_THREADS>>>(
        weight_buffer_count, top_diff, bottom_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
@@ -204,7 +204,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
   if (propagate_down[0]) {
    const Dtype* weight_data = this->blobs_[0]->gpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-    ConvolutionDepthwiseBottomBackward<Dtype><<<CAFFE_GET_BLOCKS(bottom_count),
+    ConvolutionDepthwiseBottomBackward<Dtype> <<<CAFFE_GET_BLOCKS(bottom_count),
        CAFFE_CUDA_NUM_THREADS>>>(
        bottom_count, top_diff, weight_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,

From e51f834931930b80982741bdb72d962ce741bb16 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 18:42:06 +0800
Subject: [PATCH 6/9] satisfy the code format of caffe

---
 src/caffe/layers/conv_dw_layer.cpp |  2 +-
 src/caffe/layers/conv_dw_layer.cu  | 20 ++++++++++----------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/caffe/layers/conv_dw_layer.cpp b/src/caffe/layers/conv_dw_layer.cpp
index 2cc742f8de5..7e8e96616b8 100644
--- a/src/caffe/layers/conv_dw_layer.cpp
+++ b/src/caffe/layers/conv_dw_layer.cpp
@@ -1,8 +1,8 @@
 #include <algorithm>
 #include <vector>
 #include "caffe/filler.hpp"
-#include "caffe/util/math_functions.hpp"
 #include "caffe/layers/conv_dw_layer.hpp"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
index dfeafb38f82..927e9920d9e 100644
--- a/src/caffe/layers/conv_dw_layer.cu
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -59,16 +59,16 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
   const int top_width = top[0]->width();
   const int bottom_height = bottom[0]->height();
   const int bottom_width = bottom[0]->width();
-  ConvolutionDepthwiseWeightForward<Dtype> <<<CAFFE_GET_BLOCKS(count),
-      CAFFE_CUDA_NUM_THREADS>>>(
+  ConvolutionDepthwiseWeightForward<Dtype>
+      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, weight_data, num, channels,
       top_height, top_width, bottom_height, bottom_width,
       kernel_h_, kernel_w_, stride_h_, stride_w_,
       pad_h_, pad_w_, dilation_h_, dilation_w_, top_data);
   if (this->layer_param_.convolution_param().bias_term()) {
    const Dtype* bias_data = this->blobs_[1]->gpu_data();
-    ConvolutionDepthwiseBiasForward<Dtype> <<<CAFFE_GET_BLOCKS(count),
-        CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBiasForward<Dtype>
+        <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, bias_data, num, channels,
        top_height, top_width, top_data);
   }
@@ -173,8 +173,8 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
      && this->param_propagate_down_[1]) {
    const int bias_buffer_count = bias_buffer_.count();
    Dtype* bias_buffer_mutable_data = bias_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseBiasBackward<Dtype> <<<CAFFE_GET_BLOCKS(bias_buffer_count),
-        CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBiasBackward<Dtype>
+        <<<CAFFE_GET_BLOCKS(bias_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
        bias_buffer_count, top_diff, num, channels,
        top_height, top_width, bias_buffer_mutable_data);
    const int bias_count = this->blobs_[1]->count();
@@ -188,8 +188,8 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
    const int weight_buffer_count = weight_buffer_.count();
    const Dtype* bottom_data = bottom[0]->gpu_data();
    Dtype* weight_buffer_mutable_data = weight_buffer_.mutable_gpu_data();
-    ConvolutionDepthwiseWeightBackward<Dtype> <<<CAFFE_GET_BLOCKS(weight_buffer_count),
-        CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseWeightBackward<Dtype>
+        <<<CAFFE_GET_BLOCKS(weight_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
        weight_buffer_count, top_diff, bottom_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
        kernel_h_, kernel_w_, stride_h_, stride_w_,
@@ -204,8 +204,8 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
   if (propagate_down[0]) {
    const Dtype* weight_data = this->blobs_[0]->gpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-    ConvolutionDepthwiseBottomBackward<Dtype> <<<CAFFE_GET_BLOCKS(bottom_count),
-        CAFFE_CUDA_NUM_THREADS>>>(
+    ConvolutionDepthwiseBottomBackward<Dtype>
+        <<<CAFFE_GET_BLOCKS(bottom_count), CAFFE_CUDA_NUM_THREADS>>>(
        bottom_count, top_diff, weight_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
        kernel_h_, kernel_w_, stride_h_, stride_w_,

From d40fbd8a994f1a68405101d78d25bddaf767f764 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Fri, 2 Jun 2017 18:58:22 +0800
Subject: [PATCH 7/9] satisfy the code format of caffe

---
 src/caffe/layers/conv_dw_layer.cu | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/caffe/layers/conv_dw_layer.cu b/src/caffe/layers/conv_dw_layer.cu
index 927e9920d9e..3a4f5af5b8d 100644
--- a/src/caffe/layers/conv_dw_layer.cu
+++ b/src/caffe/layers/conv_dw_layer.cu
@@ -60,6 +60,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
   const int bottom_height = bottom[0]->height();
   const int bottom_width = bottom[0]->width();
   ConvolutionDepthwiseWeightForward<Dtype>
+      // NOLINT_NEXT_LINE(whitespace/operators)
       <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, weight_data, num, channels,
       top_height, top_width, bottom_height, bottom_width,
@@ -68,6 +69,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Forward_gpu(
       pad_h_, pad_w_, dilation_h_, dilation_w_, top_data);
   if (this->layer_param_.convolution_param().bias_term()) {
    const Dtype* bias_data = this->blobs_[1]->gpu_data();
    ConvolutionDepthwiseBiasForward<Dtype>
+        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, bias_data, num, channels,
        top_height, top_width, top_data);
@@ -174,6 +176,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
    const int bias_buffer_count = bias_buffer_.count();
    Dtype* bias_buffer_mutable_data = bias_buffer_.mutable_gpu_data();
    ConvolutionDepthwiseBiasBackward<Dtype>
+        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<CAFFE_GET_BLOCKS(bias_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
        bias_buffer_count, top_diff, num, channels,
        top_height, top_width, bias_buffer_mutable_data);
@@ -189,6 +192,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
    const Dtype* bottom_data = bottom[0]->gpu_data();
    Dtype* weight_buffer_mutable_data = weight_buffer_.mutable_gpu_data();
    ConvolutionDepthwiseWeightBackward<Dtype>
+        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<CAFFE_GET_BLOCKS(weight_buffer_count), CAFFE_CUDA_NUM_THREADS>>>(
        weight_buffer_count, top_diff, bottom_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
@@ -205,6 +209,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Backward_gpu(
    const Dtype* weight_data = this->blobs_[0]->gpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
    ConvolutionDepthwiseBottomBackward<Dtype>
+        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<CAFFE_GET_BLOCKS(bottom_count), CAFFE_CUDA_NUM_THREADS>>>(
        bottom_count, top_diff, weight_data, num, channels,
        top_height, top_width, bottom_height, bottom_width,
        kernel_h_, kernel_w_, stride_h_, stride_w_,

From 013377e8053dfd852558647b4ee201f3540209d9 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Thu, 15 Jun 2017 00:33:07 +0800
Subject: [PATCH 8/9] initialize multiplier blobs through mutable_cpu_data

---
 src/caffe/layers/conv_dw_layer.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/caffe/layers/conv_dw_layer.cpp b/src/caffe/layers/conv_dw_layer.cpp
index 7e8e96616b8..a0ee6b8eae2 100644
--- a/src/caffe/layers/conv_dw_layer.cpp
+++ b/src/caffe/layers/conv_dw_layer.cpp
@@ -112,7 +112,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(
   weight_multiplier_shape.push_back(top[0]->width());
   weight_multiplier_.Reshape(weight_multiplier_shape);
   caffe_set(weight_multiplier_.count(), Dtype(1),
-      weight_multiplier_.mutable_gpu_data());
+      weight_multiplier_.mutable_cpu_data());
   if (this->layer_param_.convolution_param().bias_term()) {
    vector<int> bias_buffer_shape;
    bias_buffer_shape.push_back(bottom[0]->channels());
@@ -126,7 +126,7 @@ void ConvolutionDepthwiseLayer<Dtype>::Reshape(
    bias_multiplier_shape.push_back(top[0]->width());
    bias_multiplier_.Reshape(bias_multiplier_shape);
    caffe_set(bias_multiplier_.count(), Dtype(1),
-        bias_multiplier_.mutable_gpu_data());
+        bias_multiplier_.mutable_cpu_data());
   }
 }

From 327a0194c67bc599ade211c388087a166339bdb5 Mon Sep 17 00:00:00 2001
From: sp2823
Date: Tue, 27 Jun 2017 18:09:29 +0800
Subject: [PATCH 9/9] add missing blank line before include guard endif

---
 include/caffe/layers/conv_dw_layer.hpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/include/caffe/layers/conv_dw_layer.hpp b/include/caffe/layers/conv_dw_layer.hpp
index 8b5133ada24..02dadc0d463 100644
--- a/include/caffe/layers/conv_dw_layer.hpp
+++ b/include/caffe/layers/conv_dw_layer.hpp
@@ -46,4 +46,5 @@ class ConvolutionDepthwiseLayer : public Layer<Dtype> {
 
 }  // namespace caffe
 
+
 #endif  // CAFFE_CONV_DW_LAYER_HPP_
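Usage note: the layer registers itself as "ConvolutionDepthwise" (see REGISTER_LAYER_CLASS and type() above) and reads its hyper-parameters from the standard convolution_param message that LayerSetUp parses. Below is a minimal prototxt sketch; the layer and blob names are hypothetical, and num_output is omitted because the layer always produces exactly one output channel per input channel:

layer {
  name: "conv2_dw"              # hypothetical layer name
  type: "ConvolutionDepthwise"
  bottom: "conv1"               # hypothetical input blob
  top: "conv2_dw"
  convolution_param {
    kernel_size: 3              # kernel_h_ = kernel_w_ = 3
    stride: 1
    pad: 1                      # with dilation 1, preserves height/width
    dilation: 1
    bias_term: false
    weight_filler { type: "msra" }
  }
}

The weight blob is shaped (channels, 1, kernel_h, kernel_w), so the parameter count is channels * kernel_h * kernel_w, versus channels^2 * kernel_h * kernel_w for a dense convolution over the same shapes.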