Skip to content

Commit

Permalink
Merge pull request #1104 from shelhamer/conv-comments-tests
Browse files Browse the repository at this point in the history
Document and Test Convolution
  • Loading branch information
shelhamer committed Sep 18, 2014
2 parents c3a69b7 + 18ca362 commit 8dac339
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 115 deletions.
62 changes: 59 additions & 3 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,49 @@ namespace caffe {
* @brief Convolves the input image with a bank of learned filters,
* and (optionally) adds biases.
*
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
* Caffe convolves by reduction to matrix multiplication. This achieves
* high-throughput and generality of input and filter dimensions but comes at
* the cost of memory for matrices. This makes use of efficiency in BLAS.
*
* The input is "im2col" transformed to a channel K' x H x W data matrix
* for multiplication with the N x K' x H x W filter matrix to yield a
* N' x H x W output matrix that is then "col2im" restored. K' is the
* input channel * kernel height * kernel width dimension of the unrolled
* inputs so that the im2col matrix has a column for each input region to
* be filtered. col2im restores the output spatial structure by rolling up
* the output channel N' columns of the output matrix.
*/
template <typename Dtype>
class ConvolutionLayer : public Layer<Dtype> {
public:
/**
* @param param provides ConvolutionParameter convolution_param,
* with ConvolutionLayer options:
* - num_output. The number of filters.
* - kernel_size / kernel_h / kernel_w. The filter dimensions, given by
* kernel_size for square filters or kernel_h and kernel_w for rectangular
* filters.
* - stride / stride_h / stride_w (\b optional, default 1). The filter
* stride, given by stride_size for equal dimensions or stride_h and stride_w
* for different strides. By default the convolution is dense with stride 1.
* - pad / pad_h / pad_w (\b optional, default 0). The zero-padding for
* convolution, given by pad for equal dimensions or pad_h and pad_w for
* different padding. Input padding is computed implicitly instead of
* actually padding.
* - group (\b optional, default 1). The number of filter groups. Group
* convolution is a method for reducing parameterization by selectively
* connecting input and output channels. The input and output channel dimensions must be divisible
* by the number of groups. For group @f$ \geq 1 @f$, the
* convolutional filters' input and output channels are separated s.t. each
* group takes 1 / group of the input channels and makes 1 / group of the
* output channels. Concretely 4 input channels, 8 output channels, and
* 2 groups separate input channels 1-2 and output channels 1-4 into the
* first group and input channels 3-4 and output channels 5-8 into the second
* group.
* - bias_term (\b optional, default true). Whether to have a bias.
* - engine: convolution has CAFFE (matrix multiplication) and CUDNN (library
* kernels + stream parallelism) engines.
*/
explicit ConvolutionLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
Expand Down Expand Up @@ -57,8 +95,16 @@ class ConvolutionLayer : public Layer<Dtype> {
int num_output_;
int height_out_, width_out_;
bool bias_term_;
// For the Caffe matrix multiplication convolution.
int M_, K_, N_;

/// M_ is the channel dimension of the output for a single group, which is the
/// leading dimension of the filter matrix.
int M_;
/// K_ is the dimension of an unrolled input for a single group, which is the
/// leading dimension of the data matrix.
int K_;
/// N_ is the spatial dimension of the output, the H x W, which are the last
/// dimensions of the data and filter matrices.
int N_;
Blob<Dtype> col_buffer_;
Blob<Dtype> bias_multiplier_;
};
Expand All @@ -67,6 +113,16 @@ class ConvolutionLayer : public Layer<Dtype> {
/*
* @brief cuDNN implementation of ConvolutionLayer.
* Fallback to ConvolutionLayer for CPU mode.
*
* cuDNN accelerates convolution through forward kernels for filtering and bias
* plus backward kernels for the gradient w.r.t. the filters, biases, and
* inputs. Caffe + cuDNN further speeds up the computation through forward
* parallelism across groups and backward parallelism across gradients.
*
* The CUDNN engine does not have memory overhead for matrix buffers. For many
* input and filter regimes the CUDNN engine is faster than the CAFFE engine,
* but for fully-convolutional models and large inputs the CAFFE engine can be
* faster as long as it fits in memory.
*/
template <typename Dtype>
class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
Expand Down
58 changes: 33 additions & 25 deletions src/caffe/layers/conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace caffe {
template <typename Dtype>
void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// Configure the kernel size, padding, stride, and inputs.
ConvolutionParameter conv_param = this->layer_param_.convolution_param();
CHECK(!conv_param.has_kernel_size() !=
!(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
Expand Down Expand Up @@ -46,7 +47,6 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
stride_h_ = conv_param.stride_h();
stride_w_ = conv_param.stride_w();
}
group_ = this->layer_param_.convolution_param().group();
num_ = bottom[0]->num();
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
Expand All @@ -61,28 +61,33 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_EQ(width_, bottom[bottom_id]->width())
<< "Inputs must have same width.";
}
// Configure output channels, groups, and spatial dimensions.
num_output_ = this->layer_param_.convolution_param().num_output();
CHECK_GT(num_output_, 0);
group_ = this->layer_param_.convolution_param().group();
CHECK_EQ(channels_ % group_, 0);
// The im2col result buffer would only hold one image at a time to avoid
// overly large memory usage.
CHECK_EQ(num_output_ % group_, 0)
<< "Number of output should be multiples of group.";
height_out_ =
(height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
col_buffer_.Reshape(
1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
// Set the parameters
CHECK_EQ(num_output_ % group_, 0)
<< "Number of output should be multiples of group.";
bias_term_ = this->layer_param_.convolution_param().bias_term();
// Figure out the dimensions for individual gemms.
M_ = num_output_ / group_;
K_ = channels_ * kernel_h_ * kernel_w_ / group_;
N_ = height_out_ * width_out_;
for (int top_id = 0; top_id < top->size(); ++top_id) {
(*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
}
// Check if we need to set up the weights
// Prepare the matrix multiplication computation.
// Each input will be convolved as a single GEMM.
M_ = num_output_ / group_;
K_ = channels_ * kernel_h_ * kernel_w_ / group_;
N_ = height_out_ * width_out_;
// The im2col result buffer holds one image at a time to avoid
// overly large memory usage.
col_buffer_.Reshape(
1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
// Handle the parameters: weights and biases.
// - blobs_[0] holds the filter weights
// - blobs_[1] holds the biases (optional)
bias_term_ = this->layer_param_.convolution_param().bias_term();
// Check if we need to set up the weights.
if (this->blobs_.size() > 0) {
LOG(INFO) << "Skipping parameter initialization";
} else {
Expand All @@ -91,26 +96,28 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
this->blobs_.resize(1);
}
// Intialize the weight
// Initialize and fill the weights:
// output channels x input channels per-group x kernel height x kernel width
this->blobs_[0].reset(new Blob<Dtype>(
num_output_, channels_ / group_, kernel_h_, kernel_w_));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.convolution_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
// If necessary, initialize and fill the bias term
// If necessary, initialize and fill the biases:
// 1 x 1 x 1 x output channels
if (bias_term_) {
this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, num_output_));
shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
this->layer_param_.convolution_param().bias_filler()));
bias_filler->Fill(this->blobs_[1].get());
}
}
// Set up the all ones "bias multiplier" for adding bias using blas
// Set up the all ones "bias multiplier" for adding biases by BLAS
if (bias_term_) {
bias_multiplier_.Reshape(1, 1, 1, N_);
caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data());
}
// Propagate gradients to the parameters (as directed by backward pass).
this->param_propagate_down_.resize(this->blobs_.size(), true);
}

Expand All @@ -123,21 +130,22 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
Dtype* top_data = (*top)[i]->mutable_cpu_data();
Dtype* col_data = col_buffer_.mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
int weight_offset = M_ * K_; // number of filter parameters in a group
int col_offset = K_ * N_; // number of values in an input region / column
int top_offset = M_ * N_; // number of values in an output region / column
for (int n = 0; n < num_; ++n) {
// First, im2col
// im2col transformation: unroll input regions for filtering
// into column matrix for multplication.
im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
col_data);
// Second, innerproduct with groups
// Take inner products for groups.
for (int g = 0; g < group_; ++g) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g);
}
// third, add bias
// Add bias.
if (bias_term_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(),
Expand Down Expand Up @@ -201,7 +209,7 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
weight_diff + weight_offset * g);
}
}
// gradient w.r.t. bottom data, if necessary
// gradient w.r.t. bottom data, if necessary.
if (propagate_down[i]) {
if (weight == NULL) {
weight = this->blobs_[0]->cpu_data();
Expand Down
9 changes: 6 additions & 3 deletions src/caffe/layers/conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

namespace caffe {

/// @brief refer to CPU forward -- the BLAS implementation is the same.
template <typename Dtype>
void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Expand All @@ -20,17 +21,18 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < num_; ++n) {
// First, im2col
// im2col transformation: unroll input regions for filtering
// into column matrix for multplication.
im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
col_data);
// Second, innerproduct with groups
// Take inner products for groups.
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g);
}
// third, add bias
// Add bias.
if (bias_term_) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
Expand All @@ -41,6 +43,7 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
}
}

/// @brief refer to CPU backward -- the BLAS implementation is the same.
template <typename Dtype>
void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
Expand Down
Loading

0 comments on commit 8dac339

Please sign in to comment.