From e4d48c560de0bb70be84054b098984293434a123 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Wed, 17 Sep 2014 15:10:18 -0700 Subject: [PATCH 1/3] test convolution against explicit reference implementation To thoroughly check convolution, the output is compared against a reference implementation by explicit looping. Simple and group convolution by the Caffe and cuDNN engines are checked against the reference. --- src/caffe/test/test_convolution_layer.cpp | 217 +++++++++++++++------- 1 file changed, 149 insertions(+), 68 deletions(-) diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 0e7a8da567b..0849a3a4591 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -13,6 +13,93 @@ namespace caffe { +// Reference convolution for checking results: +// accumulate through explicit loops over input, output, and filters. +template +void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, + const vector > >& weights, + Blob* out) { + // Kernel size, stride, and pad + int kernel_h, kernel_w; + if (conv_param->has_kernel_size()) { + kernel_h = kernel_w = conv_param->kernel_size(); + } else { + kernel_h = conv_param->kernel_h(); + kernel_w = conv_param->kernel_w(); + } + int pad_h, pad_w; + if (!conv_param->has_pad_h()) { + pad_h = pad_w = conv_param->pad(); + } else { + pad_h = conv_param->pad_h(); + pad_w = conv_param->pad_w(); + } + int stride_h, stride_w; + if (!conv_param->has_stride_h()) { + stride_h = stride_w = conv_param->stride(); + } else { + stride_h = conv_param->stride_h(); + stride_w = conv_param->stride_w(); + } + // Groups + int groups = conv_param->group(); + int o_g = out->channels() / groups; + int k_g = in->channels() / groups; + int o_head, k_head; + // Convolution + const Dtype* in_data = in->cpu_data(); + const Dtype* weight_data = weights[0]->cpu_data(); + Dtype* out_data = out->mutable_cpu_data(); + for (int n = 0; n < out->num(); n++) { 
+ for (int g = 0; g < groups; g++) { + o_head = o_g * g; + k_head = k_g * g; + for (int o = 0; o < o_g; o++) { + for (int k = 0; k < k_g; k++) { + for (int y = 0; y < out->height(); y++) { + for (int x = 0; x < out->width(); x++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_y = y * stride_h - pad_h + p; + int in_x = x * stride_w - pad_w + q; + if (in_y >= 0 && in_y < in->height() + && in_x >= 0 && in_x < in->width()) { + out_data[out->offset(n, o + o_head, y, x)] += + in_data[in->offset(n, k + k_head, in_y, in_x)] + * weight_data[weights[0]->offset(o + o_head, k, p, q)]; + } + } + } + } + } + } + } + } + } + // Bias + if (conv_param->bias_term()) { + const Dtype* bias_data = weights[1]->cpu_data(); + for (int n = 0; n < out->num(); n++) { + for (int o = 0; o < out->channels(); o++) { + for (int y = 0; y < out->height(); y++) { + for (int x = 0; x < out->width(); x++) { + out_data[out->offset(n, o, y, x)] += bias_data[o]; + } + } + } + } + } +} + +template void caffe_conv(const Blob* in, + ConvolutionParameter* conv_param, + const vector > >& weights, + Blob* out); +template void caffe_conv(const Blob* in, + ConvolutionParameter* conv_param, + const vector > >& weights, + Blob* out); + template class ConvolutionLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; @@ -41,10 +128,17 @@ class ConvolutionLayerTest : public MultiDeviceTest { delete blob_top_2_; } + virtual Blob* MakeReferenceTop(Blob* top) { + this->ref_blob_top_.reset(new Blob()); + this->ref_blob_top_->ReshapeLike(*top); + return this->ref_blob_top_.get(); + } + Blob* const blob_bottom_; Blob* const blob_bottom_2_; Blob* const blob_top_; Blob* const blob_top_2_; + shared_ptr > ref_blob_top_; vector*> blob_bottom_vec_; vector*> blob_top_vec_; }; @@ -90,14 +184,6 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) { TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { // We will simply see if the convolution layer carries out averaging 
well. typedef typename TypeParam::Dtype Dtype; - shared_ptr > filler; - FillerParameter filler_param; - filler_param.set_value(1.); - filler.reset(new ConstantFiller(filler_param)); - filler->Fill(this->blob_bottom_); - filler_param.set_value(2.); - filler.reset(new ConstantFiller(filler_param)); - filler->Fill(this->blob_bottom_2_); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); LayerParameter layer_param; @@ -114,34 +200,28 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { new ConvolutionLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - // After the convolution, the output should all have output values 27.1 - const Dtype* top_data = this->blob_top_->cpu_data(); + // Check against reference convolution. + const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); for (int i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], 27.1, 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); - for (int i = 0; i < this->blob_top_2_->count(); ++i) { - EXPECT_NEAR(top_data[i], 54.1, 1e-4); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_2_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } } TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { // We will simply see if the convolution layer carries out averaging well. 
typedef typename TypeParam::Dtype Dtype; - FillerParameter filler_param; - filler_param.set_value(1.); - ConstantFiller filler(filler_param); - filler.Fill(this->blob_bottom_); - Dtype* bottom_data = this->blob_bottom_->mutable_cpu_data(); - for (int n = 0; n < this->blob_bottom_->num(); ++n) { - for (int c = 0; c < this->blob_bottom_->channels(); ++c) { - for (int h = 0; h < this->blob_bottom_->height(); ++h) { - for (int w = 0; w < this->blob_bottom_->width(); ++w) { - bottom_data[this->blob_bottom_->offset(n, c, h, w)] = c; - } - } - } - } LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); @@ -157,17 +237,15 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { new ConvolutionLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - // After the convolution, the output should all have output values 9.1 - const Dtype* top_data = this->blob_top_->cpu_data(); - for (int n = 0; n < this->blob_top_->num(); ++n) { - for (int c = 0; c < this->blob_top_->channels(); ++c) { - for (int h = 0; h < this->blob_top_->height(); ++h) { - for (int w = 0; w < this->blob_top_->width(); ++w) { - Dtype data = top_data[this->blob_top_->offset(n, c, h, w)]; - EXPECT_NEAR(data, c * 9 + 0.1, 1e-4); - } - } - } + // Check against reference convolution. 
+ const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } } @@ -330,10 +408,17 @@ class CuDNNConvolutionLayerTest : public ::testing::Test { delete blob_top_2_; } + virtual Blob* MakeReferenceTop(Blob* top) { + this->ref_blob_top_.reset(new Blob()); + this->ref_blob_top_->ReshapeLike(*top); + return this->ref_blob_top_.get(); + } + Blob* const blob_bottom_; Blob* const blob_bottom_2_; Blob* const blob_top_; Blob* const blob_top_2_; + shared_ptr > ref_blob_top_; vector*> blob_bottom_vec_; vector*> blob_top_vec_; }; @@ -342,6 +427,8 @@ TYPED_TEST_CASE(CuDNNConvolutionLayerTest, TestDtypes); TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) { Caffe::set_mode(Caffe::GPU); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); @@ -379,6 +466,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) { TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { // We will simply see if the convolution layer carries out averaging well. 
Caffe::set_mode(Caffe::GPU); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); shared_ptr > filler; FillerParameter filler_param; filler_param.set_value(1.); @@ -403,34 +492,28 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { new CuDNNConvolutionLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - // After the convolution, the output should all have output values 27.1 - const TypeParam* top_data = this->blob_top_->cpu_data(); + // Check against reference convolution. + const TypeParam* top_data; + const TypeParam* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); for (int i = 0; i < this->blob_top_->count(); ++i) { - EXPECT_NEAR(top_data[i], 27.1, 1e-4); + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); top_data = this->blob_top_2_->cpu_data(); - for (int i = 0; i < this->blob_top_2_->count(); ++i) { - EXPECT_NEAR(top_data[i], 54.1, 1e-4); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_2_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } } TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { // We will simply see if the convolution layer carries out averaging well. 
Caffe::set_mode(Caffe::GPU); - FillerParameter filler_param; - filler_param.set_value(1.); - ConstantFiller filler(filler_param); - filler.Fill(this->blob_bottom_); - TypeParam* bottom_data = this->blob_bottom_->mutable_cpu_data(); - for (int n = 0; n < this->blob_bottom_->num(); ++n) { - for (int c = 0; c < this->blob_bottom_->channels(); ++c) { - for (int h = 0; h < this->blob_bottom_->height(); ++h) { - for (int w = 0; w < this->blob_bottom_->width(); ++w) { - bottom_data[this->blob_bottom_->offset(n, c, h, w)] = c; - } - } - } - } LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); @@ -446,17 +529,15 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { new CuDNNConvolutionLayer(layer_param)); layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - // After the convolution, the output should all have output values 9.1 - const TypeParam* top_data = this->blob_top_->cpu_data(); - for (int n = 0; n < this->blob_top_->num(); ++n) { - for (int c = 0; c < this->blob_top_->channels(); ++c) { - for (int h = 0; h < this->blob_top_->height(); ++h) { - for (int w = 0; w < this->blob_top_->width(); ++w) { - TypeParam data = top_data[this->blob_top_->offset(n, c, h, w)]; - EXPECT_NEAR(data, c * 9 + 0.1, 1e-4); - } - } - } + // Check against reference convolution. 
+ const TypeParam* top_data; + const TypeParam* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); } } From 355af161c2d04f474b5f9a26fef7c2822e3b88bc Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 18 Sep 2014 08:40:00 -0700 Subject: [PATCH 2/3] test convolution by random weights for robustness --- src/caffe/test/test_convolution_layer.cpp | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 0849a3a4591..a38ad3fd1a8 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -192,8 +192,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { convolution_param->set_kernel_size(3); convolution_param->set_stride(2); convolution_param->set_num_output(4); - convolution_param->mutable_weight_filler()->set_type("constant"); - convolution_param->mutable_weight_filler()->set_value(1); + convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); convolution_param->mutable_bias_filler()->set_value(0.1); shared_ptr > layer( @@ -229,8 +228,7 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { convolution_param->set_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); - convolution_param->mutable_weight_filler()->set_type("constant"); - convolution_param->mutable_weight_filler()->set_value(1); + convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); convolution_param->mutable_bias_filler()->set_value(0.1); shared_ptr > layer( @@ -468,24 
+466,13 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { Caffe::set_mode(Caffe::GPU); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - shared_ptr > filler; - FillerParameter filler_param; - filler_param.set_value(1.); - filler.reset(new ConstantFiller(filler_param)); - filler->Fill(this->blob_bottom_); - filler_param.set_value(2.); - filler.reset(new ConstantFiller(filler_param)); - filler->Fill(this->blob_bottom_2_); - this->blob_bottom_vec_.push_back(this->blob_bottom_2_); - this->blob_top_vec_.push_back(this->blob_top_2_); LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); convolution_param->set_kernel_size(3); convolution_param->set_stride(2); convolution_param->set_num_output(4); - convolution_param->mutable_weight_filler()->set_type("constant"); - convolution_param->mutable_weight_filler()->set_value(1); + convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); convolution_param->mutable_bias_filler()->set_value(0.1); shared_ptr > layer( @@ -521,8 +508,7 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { convolution_param->set_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); - convolution_param->mutable_weight_filler()->set_type("constant"); - convolution_param->mutable_weight_filler()->set_value(1); + convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); convolution_param->mutable_bias_filler()->set_value(0.1); shared_ptr > layer( From 18ca3625c6accd021a4e05e6fdb77c8b6b06dc8e Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Wed, 17 Sep 2014 23:38:31 -0700 Subject: [PATCH 3/3] [docs] comment ConvolutionLayer --- include/caffe/vision_layers.hpp | 62 +++++++++++++++++++++++++++++++-- src/caffe/layers/conv_layer.cpp 
| 58 +++++++++++++++++------------- src/caffe/layers/conv_layer.cu | 9 +++-- 3 files changed, 98 insertions(+), 31 deletions(-) diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 8c2db8824ed..9c8656f1f2b 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -20,11 +20,49 @@ namespace caffe { * @brief Convolves the input image with a bank of learned filters, * and (optionally) adds biases. * - * TODO(dox): thorough documentation for Forward, Backward, and proto params. + * Caffe convolves by reduction to matrix multiplication. This achieves + * high-throughput and generality of input and filter dimensions but comes at + * the cost of memory for matrices. This makes use of efficiency in BLAS. + * + * The input is "im2col" transformed to a channel K' x H x W data matrix + * for multiplication with the N x K' x H x W filter matrix to yield a + * N' x H x W output matrix that is then "col2im" restored. K' is the + * input channel * kernel height * kernel width dimension of the unrolled + * inputs so that the im2col matrix has a column for each input region to + * be filtered. col2im restores the output spatial structure by rolling up + * the output channel N' columns of the output matrix. */ template class ConvolutionLayer : public Layer { public: + /** + * @param param provides ConvolutionParameter convolution_param, + * with ConvolutionLayer options: + * - num_output. The number of filters. + * - kernel_size / kernel_h / kernel_w. The filter dimensions, given by + * kernel_size for square filters or kernel_h and kernel_w for rectangular + * filters. + * - stride / stride_h / stride_w (\b optional, default 1). The filter + * stride, given by stride for equal dimensions or stride_h and stride_w + * for different strides. By default the convolution is dense with stride 1. + * - pad / pad_h / pad_w (\b optional, default 0). 
The zero-padding for + * convolution, given by pad for equal dimensions or pad_h and pad_w for + * different padding. Input padding is computed implicitly instead of + * actually padding. + * - group (\b optional, default 1). The number of filter groups. Group + * convolution is a method for reducing parameterization by selectively + * connecting input and output channels. The input and output channel dimensions must be divisible + * by the number of groups. For group @f$ \geq 1 @f$, the + * convolutional filters' input and output channels are separated s.t. each + * group takes 1 / group of the input channels and makes 1 / group of the + * output channels. Concretely 4 input channels, 8 output channels, and + * 2 groups separate input channels 1-2 and output channels 1-4 into the + * first group and input channels 3-4 and output channels 5-8 into the second + * group. + * - bias_term (\b optional, default true). Whether to have a bias. + * - engine: convolution has CAFFE (matrix multiplication) and CUDNN (library + * kernels + stream parallelism) engines. + */ explicit ConvolutionLayer(const LayerParameter& param) : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, @@ -57,8 +95,16 @@ class ConvolutionLayer : public Layer { int num_output_; int height_out_, width_out_; bool bias_term_; - // For the Caffe matrix multiplication convolution. - int M_, K_, N_; + + /// M_ is the channel dimension of the output for a single group, which is the + /// leading dimension of the filter matrix. + int M_; + /// K_ is the dimension of an unrolled input for a single group, which is the + /// leading dimension of the data matrix. + int K_; + /// N_ is the spatial dimension of the output, the H x W, which are the last + /// dimensions of the data and filter matrices. + int N_; Blob col_buffer_; Blob bias_multiplier_; }; @@ -67,6 +113,16 @@ class ConvolutionLayer : public Layer { /* * @brief cuDNN implementation of ConvolutionLayer. 
* Fallback to ConvolutionLayer for CPU mode. + * + * cuDNN accelerates convolution through forward kernels for filtering and bias + * plus backward kernels for the gradient w.r.t. the filters, biases, and + * inputs. Caffe + cuDNN further speeds up the computation through forward + * parallelism across groups and backward parallelism across gradients. + * + * The CUDNN engine does not have memory overhead for matrix buffers. For many + * input and filter regimes the CUDNN engine is faster than the CAFFE engine, + * but for fully-convolutional models and large inputs the CAFFE engine can be + * faster as long as it fits in memory. */ template class CuDNNConvolutionLayer : public ConvolutionLayer { diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index 16b6a688968..769dfa671f6 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -11,6 +11,7 @@ namespace caffe { template void ConvolutionLayer::LayerSetUp(const vector*>& bottom, vector*>* top) { + // Configure the kernel size, padding, stride, and inputs. ConvolutionParameter conv_param = this->layer_param_.convolution_param(); CHECK(!conv_param.has_kernel_size() != !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) @@ -46,7 +47,6 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, stride_h_ = conv_param.stride_h(); stride_w_ = conv_param.stride_w(); } - group_ = this->layer_param_.convolution_param().group(); num_ = bottom[0]->num(); channels_ = bottom[0]->channels(); height_ = bottom[0]->height(); @@ -61,28 +61,33 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK_EQ(width_, bottom[bottom_id]->width()) << "Inputs must have same width."; } + // Configure output channels, groups, and spatial dimensions. 
num_output_ = this->layer_param_.convolution_param().num_output(); CHECK_GT(num_output_, 0); + group_ = this->layer_param_.convolution_param().group(); CHECK_EQ(channels_ % group_, 0); - // The im2col result buffer would only hold one image at a time to avoid - // overly large memory usage. + CHECK_EQ(num_output_ % group_, 0) + << "Number of output should be multiples of group."; height_out_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; - col_buffer_.Reshape( - 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_); - // Set the parameters - CHECK_EQ(num_output_ % group_, 0) - << "Number of output should be multiples of group."; - bias_term_ = this->layer_param_.convolution_param().bias_term(); - // Figure out the dimensions for individual gemms. - M_ = num_output_ / group_; - K_ = channels_ * kernel_h_ * kernel_w_ / group_; - N_ = height_out_ * width_out_; for (int top_id = 0; top_id < top->size(); ++top_id) { (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_); } - // Check if we need to set up the weights + // Prepare the matrix multiplication computation. + // Each input will be convolved as a single GEMM. + M_ = num_output_ / group_; + K_ = channels_ * kernel_h_ * kernel_w_ / group_; + N_ = height_out_ * width_out_; + // The im2col result buffer holds one image at a time to avoid + // overly large memory usage. + col_buffer_.Reshape( + 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_); + // Handle the parameters: weights and biases. + // - blobs_[0] holds the filter weights + // - blobs_[1] holds the biases (optional) + bias_term_ = this->layer_param_.convolution_param().bias_term(); + // Check if we need to set up the weights. 
if (this->blobs_.size() > 0) { LOG(INFO) << "Skipping parameter initialization"; } else { @@ -91,14 +96,15 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, } else { this->blobs_.resize(1); } - // Intialize the weight + // Initialize and fill the weights: + // output channels x input channels per-group x kernel height x kernel width this->blobs_[0].reset(new Blob( num_output_, channels_ / group_, kernel_h_, kernel_w_)); - // fill the weights shared_ptr > weight_filler(GetFiller( this->layer_param_.convolution_param().weight_filler())); weight_filler->Fill(this->blobs_[0].get()); - // If necessary, initialize and fill the bias term + // If necessary, initialize and fill the biases: + // 1 x 1 x 1 x output channels if (bias_term_) { this->blobs_[1].reset(new Blob(1, 1, 1, num_output_)); shared_ptr > bias_filler(GetFiller( @@ -106,11 +112,12 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, bias_filler->Fill(this->blobs_[1].get()); } } - // Set up the all ones "bias multiplier" for adding bias using blas + // Set up the all ones "bias multiplier" for adding biases by BLAS if (bias_term_) { bias_multiplier_.Reshape(1, 1, 1, N_); caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data()); } + // Propagate gradients to the parameters (as directed by backward pass). 
this->param_propagate_down_.resize(this->blobs_.size(), true); } @@ -123,21 +130,22 @@ void ConvolutionLayer::Forward_cpu(const vector*>& bottom, Dtype* top_data = (*top)[i]->mutable_cpu_data(); Dtype* col_data = col_buffer_.mutable_cpu_data(); const Dtype* weight = this->blobs_[0]->cpu_data(); - int weight_offset = M_ * K_; - int col_offset = K_ * N_; - int top_offset = M_ * N_; + int weight_offset = M_ * K_; // number of filter parameters in a group + int col_offset = K_ * N_; // number of values in an input region / column + int top_offset = M_ * N_; // number of values in an output region / column for (int n = 0; n < num_; ++n) { - // First, im2col + // im2col transformation: unroll input regions for filtering + // into column matrix for multiplication. im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_, width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_data); - // Second, innerproduct with groups + // Take inner products for groups. for (int g = 0; g < group_; ++g) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1., weight + weight_offset * g, col_data + col_offset * g, (Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g); } - // third, add bias + // Add bias. if (bias_term_) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(), @@ -201,7 +209,7 @@ void ConvolutionLayer::Backward_cpu(const vector*>& top, weight_diff + weight_offset * g); } } - // gradient w.r.t. bottom data, if necessary + // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { if (weight == NULL) { weight = this->blobs_[0]->cpu_data(); diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index 02cfdb3b80c..43f76a2368f 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -8,6 +8,7 @@ namespace caffe { +/// @brief refer to CPU forward -- the BLAS implementation is the same. 
template void ConvolutionLayer::Forward_gpu(const vector*>& bottom, vector*>* top) { @@ -20,17 +21,18 @@ void ConvolutionLayer::Forward_gpu(const vector*>& bottom, int col_offset = K_ * N_; int top_offset = M_ * N_; for (int n = 0; n < num_; ++n) { - // First, im2col + // im2col transformation: unroll input regions for filtering + // into column matrix for multiplication. im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_, width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_data); - // Second, innerproduct with groups + // Take inner products for groups. for (int g = 0; g < group_; ++g) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1., weight + weight_offset * g, col_data + col_offset * g, (Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g); } - // third, add bias + // Add bias. if (bias_term_) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(), @@ -41,6 +43,7 @@ void ConvolutionLayer::Forward_gpu(const vector*>& bottom, } } +/// @brief refer to CPU backward -- the BLAS implementation is the same. template void ConvolutionLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, vector*>* bottom) {