
Commit

Merge pull request BVLC#227 from shelhamer/padding-deprecation
Bring back padding layer to ease release upgrade
longjon committed Mar 18, 2014
2 parents 2fec8cf + 996d996 commit efcb9c5
Showing 5 changed files with 305 additions and 0 deletions.
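
Context for the upgrade path: the deprecation warning added in padding_layer.cpp below asks users to replace explicit padding layers with the pad field on convolutions (see https://github.com/BVLC/caffe/pull/128). A rough before/after sketch of that migration, using the flat prototxt fields of this release (layer names here are made up; field names approximate):

layers {
  layer { name: "pad2" type: "padding" pad: 2 }
  bottom: "norm1"
  top: "pad2"
}
layers {
  layer { name: "conv2" type: "conv" num_output: 256 kernelsize: 5 }
  bottom: "pad2"
  top: "conv2"
}

becomes a single padding-aware convolution:

layers {
  layer { name: "conv2" type: "conv" num_output: 256 kernelsize: 5 pad: 2 }
  bottom: "norm1"
  top: "conv2"
}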
28 changes: 28 additions & 0 deletions include/caffe/vision_layers.hpp
@@ -195,6 +195,34 @@ class InnerProductLayer : public Layer<Dtype> {
shared_ptr<SyncedMemory> bias_multiplier_;
};


template <typename Dtype>
class PaddingLayer : public Layer<Dtype> {
 public:
  explicit PaddingLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  unsigned int PAD_;
  int NUM_;
  int CHANNEL_;
  int HEIGHT_IN_;
  int WIDTH_IN_;
  int HEIGHT_OUT_;
  int WIDTH_OUT_;
};


template <typename Dtype>
class LRNLayer : public Layer<Dtype> {
public:
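The new members only cache the symmetric pad and the blob shapes: SetUp (below) computes HEIGHT_OUT_ = HEIGHT_IN_ + 2 * PAD_ and WIDTH_OUT_ = WIDTH_IN_ + 2 * PAD_, while num and channels pass through unchanged. For example, pad: 2 on a 3 x 227 x 227 bottom blob reshapes the top blob to 3 x 231 x 231.
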
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -45,6 +45,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
    return new InnerProductLayer<Dtype>(param);
  } else if (type == "lrn") {
    return new LRNLayer<Dtype>(param);
  } else if (type == "padding") {
    return new PaddingLayer<Dtype>(param);
  } else if (type == "pool") {
    return new PoolingLayer<Dtype>(param);
  } else if (type == "relu") {
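For context, GetLayer is the string-keyed factory the net uses to turn each prototxt layer definition into a layer object; the new branch simply maps type: "padding" back to the restored class. A minimal sketch of that path (hypothetical values, assuming the flat LayerParameter of this release):

LayerParameter param;
param.set_name("pad1");      // hypothetical layer name
param.set_type("padding");   // hits the branch added above
param.set_pad(2);            // symmetric pad read later in PaddingLayer::SetUp
Layer<float>* layer = GetLayer<float>(param);  // returns a PaddingLayer<float>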
74 changes: 74 additions & 0 deletions src/caffe/layers/padding_layer.cpp
@@ -0,0 +1,74 @@
// Copyright 2013 Yangqing Jia

#include <iostream> // NOLINT(readability/streams)
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  // DEPRECATION
  LOG(WARNING) << "Padding layers are deprecated in favor of padding-aware "
      "convolutions and WILL BE REMOVED. Please update your model "
      "prototxt to replace padding layers with pad fields. "
      "See https://github.com/BVLC/caffe/pull/128.";
  PAD_ = this->layer_param_.pad();
  CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
  NUM_ = bottom[0]->num();
  CHANNEL_ = bottom[0]->channels();
  HEIGHT_IN_ = bottom[0]->height();
  WIDTH_IN_ = bottom[0]->width();
  HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
  WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
  (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
}

template <typename Dtype>
void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  const Dtype* bottom_data = bottom[0]->cpu_data();
  memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
  // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
  for (int n = 0; n < NUM_; ++n) {
    for (int c = 0; c < CHANNEL_; ++c) {
      for (int h = 0; h < HEIGHT_IN_; ++h) {
        // copy the width part
        memcpy(
            top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
                * WIDTH_OUT_ + PAD_,
            bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
            sizeof(Dtype) * WIDTH_IN_);
      }
    }
  }
}

template <typename Dtype>
Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  const Dtype* top_diff = top[0]->cpu_diff();
  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
  for (int n = 0; n < NUM_; ++n) {
    for (int c = 0; c < CHANNEL_; ++c) {
      for (int h = 0; h < HEIGHT_IN_; ++h) {
        // copy the width part
        memcpy(
            bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
            top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
                * WIDTH_OUT_ + PAD_,
            sizeof(Dtype) * WIDTH_IN_);
      }
    }
  }
  return Dtype(0.);
}

INSTANTIATE_CLASS(PaddingLayer);

} // namespace caffe
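
The memcpy offsets above are easier to see with concrete numbers. A worked example (illustrative, not part of the commit), using the 2 x 3 x 4 x 5 bottom blob and pad = 1 from the tests below:

// bottom: 2 x 3 x 4 x 5, pad = 1  =>  top: 2 x 3 x 6 x 7
// Source row (n=1, c=2, h=3) starts at bottom offset
//   ((1 * 3 + 2) * 4 + 3) * 5 = 115
// and its 5 values are copied to top offset
//   ((1 * 3 + 2) * 6 + 3 + 1) * 7 + 1 = 239,
// i.e. a width-5 run centered inside a zero-filled width-7 padded row.
// Backward_cpu copies the same width-5 run of diffs in the other direction.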
84 changes: 84 additions & 0 deletions src/caffe/layers/padding_layer.cu
@@ -0,0 +1,84 @@
// Copyright 2013 Yangqing Jia

#include <iostream> // NOLINT(readability/streams)
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
    const int num, const int channel, const int height_in, const int width_in,
    const int pad) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < count) {
    int height_out = height_in + pad + pad;
    int width_out = width_in + pad + pad;
    int w = index % width_in;
    index /= width_in;
    int h = index % height_in;
    index /= height_in;
    int c = index % channel;
    index /= channel;
    out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
        in[((index * channel + c) * height_in + h) * width_in + w];
  }
}

template <typename Dtype>
void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  const int count = bottom[0]->count();
  // First, set all data to be zero for the boundary pixels
  CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
  // NOLINT_NEXT_LINE(whitespace/operators)
  PaddingForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
      PAD_);
  CUDA_POST_KERNEL_CHECK;
}

template <typename Dtype>
__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
    const int num, const int channel, const int height_in, const int width_in,
    const int pad) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < count) {
    int height_out = height_in + pad + pad;
    int width_out = width_in + pad + pad;
    int w = index % width_in;
    index /= width_in;
    int h = index % height_in;
    index /= height_in;
    int c = index % channel;
    index /= channel;
    out[((index * channel + c) * height_in + h) * width_in + w] =
        in[((index * channel + c) * height_out + h + pad) *
            width_out + pad + w];
  }
}

template <typename Dtype>
Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down,
      vector<Blob<Dtype>*>* bottom) {
  if (propagate_down) {
    const Dtype* top_diff = top[0]->gpu_diff();
    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
    const int count = (*bottom)[0]->count();
    // NOLINT_NEXT_LINE(whitespace/operators)
    PaddingBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
        PAD_);
    CUDA_POST_KERNEL_CHECK;
  }
  return Dtype(0);
}

INSTANTIATE_CLASS(PaddingLayer);

} // namespace caffe
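
Both kernels launch one thread per element of the unpadded, bottom-shaped volume (count = num * channel * height_in * width_in) and recover (n, c, h, w) from the flat thread index by repeated mod/div on the innermost dimensions. The same decomposition as a host-side sketch (helper name hypothetical, not part of the commit):

inline void UnravelIndex(int index, int channel, int height_in, int width_in,
    int* n, int* c, int* h, int* w) {
  *w = index % width_in;   index /= width_in;   // column within the row
  *h = index % height_in;  index /= height_in;  // row within the channel
  *c = index % channel;    index /= channel;    // channel within the image
  *n = index;                                   // remaining image index < num
}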
117 changes: 117 additions & 0 deletions src/caffe/test/test_padding_layer.cpp
@@ -0,0 +1,117 @@
// Copyright 2013 Yangqing Jia

#include <cuda_runtime.h>
#include <cstring>
#include <vector>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/test/test_gradient_check_util.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

template <typename Dtype>
class PaddingLayerTest : public ::testing::Test {
 protected:
  PaddingLayerTest()
      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
        blob_top_(new Blob<Dtype>()) {
    // fill the values
    FillerParameter filler_param;
    GaussianFiller<Dtype> filler(filler_param);
    filler.Fill(this->blob_bottom_);
    blob_bottom_vec_.push_back(blob_bottom_);
    blob_top_vec_.push_back(blob_top_);
  }
  virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; }
  Blob<Dtype>* const blob_bottom_;
  Blob<Dtype>* const blob_top_;
  vector<Blob<Dtype>*> blob_bottom_vec_;
  vector<Blob<Dtype>*> blob_top_vec_;
};

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(PaddingLayerTest, Dtypes);

TYPED_TEST(PaddingLayerTest, TestCPU) {
  LayerParameter layer_param;
  layer_param.set_pad(1);
  Caffe::set_mode(Caffe::CPU);
  PaddingLayer<TypeParam> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  EXPECT_EQ(this->blob_top_->num(), 2);
  EXPECT_EQ(this->blob_top_->channels(), 3);
  EXPECT_EQ(this->blob_top_->height(), 6);
  EXPECT_EQ(this->blob_top_->width(), 7);
  for (int n = 0; n < 2; ++n) {
    for (int c = 0; c < 3; ++c) {
      for (int h = 0; h < 4; ++h) {
        for (int w = 0; w < 5; ++w) {
          EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
              this->blob_top_->data_at(n, c, h + 1, w + 1));
        }
      }
    }
  }
}

TYPED_TEST(PaddingLayerTest, TestCPUGrad) {
  LayerParameter layer_param;
  layer_param.set_pad(1);
  Caffe::set_mode(Caffe::CPU);
  PaddingLayer<TypeParam> layer(layer_param);
  GradientChecker<TypeParam> checker(1e-2, 1e-3);
  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
      &(this->blob_top_vec_));
}

TYPED_TEST(PaddingLayerTest, TestGPU) {
  if (CAFFE_TEST_CUDA_PROP.major >= 2) {
    LayerParameter layer_param;
    layer_param.set_pad(1);
    Caffe::set_mode(Caffe::GPU);
    PaddingLayer<TypeParam> layer(layer_param);
    layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
    layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
    EXPECT_EQ(this->blob_top_->num(), 2);
    EXPECT_EQ(this->blob_top_->channels(), 3);
    EXPECT_EQ(this->blob_top_->height(), 6);
    EXPECT_EQ(this->blob_top_->width(), 7);
    for (int n = 0; n < 2; ++n) {
      for (int c = 0; c < 3; ++c) {
        for (int h = 0; h < 4; ++h) {
          for (int w = 0; w < 5; ++w) {
            EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w),
                this->blob_top_->data_at(n, c, h + 1, w + 1));
          }
        }
      }
    }
  } else {
    LOG(ERROR) << "Skipping test (gpu version too low).";
  }
}

TYPED_TEST(PaddingLayerTest, TestGPUGrad) {
  if (CAFFE_TEST_CUDA_PROP.major >= 2) {
    LayerParameter layer_param;
    layer_param.set_pad(1);
    Caffe::set_mode(Caffe::GPU);
    PaddingLayer<TypeParam> layer(layer_param);
    GradientChecker<TypeParam> checker(1e-2, 1e-3);
    checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
        &(this->blob_top_vec_));
  } else {
    LOG(ERROR) << "Skipping test (gpu version too low).";
  }
}

} // namespace caffe
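
The gradient tests rely on Caffe's finite-difference GradientChecker: the two constructor arguments are the perturbation step (1e-2) and the relative error threshold (1e-3), and CheckGradientExhaustive compares numerical derivatives of every top value with respect to every bottom value against what Backward reports. Conceptually, for a single (bottom i, top j) pair it does something like the following central-difference estimate (a sketch only, with hypothetical names; the real checker lives in test_gradient_check_util.hpp):

float EstimateGradient(PaddingLayer<float>& layer,
    const vector<Blob<float>*>& bottom, vector<Blob<float>*>* top,
    int i, int j, float step) {
  float* data = bottom[0]->mutable_cpu_data();
  const float original = data[i];
  data[i] = original + step;
  layer.Forward(bottom, top);
  const float positive = (*top)[0]->cpu_data()[j];
  data[i] = original - step;
  layer.Forward(bottom, top);
  const float negative = (*top)[0]->cpu_data()[j];
  data[i] = original;  // restore the input
  return (positive - negative) / (2 * step);
}

For a padding layer the estimate should be 1 where top (n, c, h + pad, w + pad) maps to bottom (n, c, h, w) and 0 everywhere else, which is exactly what the analytic Backward produces.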
