Skip to content

Commit

Permalink
Merge pull request #1977 from shelhamer/accum-grad
Browse files Browse the repository at this point in the history
Decouple the computational batch size and minibatch size by accumulating gradients
  • Loading branch information
shelhamer committed May 30, 2015
2 parents 8b05a02 + 0e7a078 commit aeef453
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 41 deletions.
1 change: 1 addition & 0 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class SGDSolver : public Solver<Dtype> {
void PreSolve();
Dtype GetLearningRate();
virtual void ApplyUpdate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
Expand Down
7 changes: 5 additions & 2 deletions include/caffe/test/test_gradient_check_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
CHECK_EQ(top_count, bottom[blob_id]->count());
}
}
// First, figure out what blobs we need to check against.
// First, figure out what blobs we need to check against, and zero init
// parameter blobs.
vector<Blob<Dtype>*> blobs_to_check;
vector<bool> propagate_down(bottom.size(), check_bottom < 0);
for (int i = 0; i < layer->blobs().size(); ++i) {
blobs_to_check.push_back(layer->blobs()[i].get());
Blob<Dtype>* blob = layer->blobs()[i].get();
caffe_set(blob->count(), static_cast<Dtype>(0), blob->mutable_cpu_diff());
blobs_to_check.push_back(blob);
}
if (check_bottom < 0) {
for (int i = 0; i < bottom.size(); ++i) {
Expand Down
7 changes: 0 additions & 7 deletions src/caffe/layers/conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* weight = this->blobs_[0]->cpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
if (this->param_propagate_down_[0]) {
caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
if (this->bias_term_ && this->param_propagate_down_[1]) {
caffe_set(this->blobs_[1]->count(), Dtype(0),
this->blobs_[1]->mutable_cpu_diff());
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->cpu_diff();
const Dtype* bottom_data = bottom[i]->cpu_data();
Expand Down
7 changes: 0 additions & 7 deletions src/caffe/layers/conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* weight = this->blobs_[0]->gpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
if (this->param_propagate_down_[0]) {
caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
if (this->bias_term_ && this->param_propagate_down_[1]) {
caffe_gpu_set(this->blobs_[1]->count(), Dtype(0),
this->blobs_[1]->mutable_gpu_diff());
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();
// Bias gradient, if necessary.
Expand Down
2 changes: 0 additions & 2 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,10 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (this->param_propagate_down_[0]) {
weight = this->blobs_[0]->gpu_data();
weight_diff = this->blobs_[0]->mutable_gpu_diff();
caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
Dtype* bias_diff = NULL;
if (this->bias_term_ && this->param_propagate_down_[1]) {
bias_diff = this->blobs_[1]->mutable_gpu_diff();
caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();
Expand Down
7 changes: 0 additions & 7 deletions src/caffe/layers/deconv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ void DeconvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* weight = this->blobs_[0]->cpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
if (this->param_propagate_down_[0]) {
caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
if (this->bias_term_ && this->param_propagate_down_[1]) {
caffe_set(this->blobs_[1]->count(), Dtype(0),
this->blobs_[1]->mutable_cpu_diff());
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->cpu_diff();
const Dtype* bottom_data = bottom[i]->cpu_data();
Expand Down
7 changes: 0 additions & 7 deletions src/caffe/layers/deconv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ void DeconvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* weight = this->blobs_[0]->gpu_data();
Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
if (this->param_propagate_down_[0]) {
caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
if (this->bias_term_ && this->param_propagate_down_[1]) {
caffe_gpu_set(this->blobs_[1]->count(), Dtype(0),
this->blobs_[1]->mutable_gpu_diff());
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();
const Dtype* bottom_data = bottom[i]->gpu_data();
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/layers/inner_product_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
// Gradient with respect to bias
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
bias_multiplier_.cpu_data(), (Dtype)0.,
bias_multiplier_.cpu_data(), (Dtype)1.,
this->blobs_[1]->mutable_cpu_diff());
}
if (propagate_down[0]) {
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/layers/inner_product_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
// Gradient with respect to weight
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
// Gradient with respect to bias
caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
bias_multiplier_.gpu_data(), (Dtype)0.,
bias_multiplier_.gpu_data(), (Dtype)1.,
this->blobs_[1]->mutable_gpu_diff());
}
if (propagate_down[0]) {
Expand Down
4 changes: 3 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 36 (last added: clip_gradients)
// SolverParameter next available ID: 37 (last added: iter_size)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
Expand Down Expand Up @@ -149,6 +149,8 @@ message SolverParameter {
// Display the loss averaged over the last average_loss iterations
optional int32 average_loss = 33 [default = 1];
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];
optional string lr_policy = 8; // The learning rate decay policy.
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
Expand Down
54 changes: 53 additions & 1 deletion src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,39 @@ void Solver<Dtype>::Step(int iters) {
Dtype smoothed_loss = 0;

while (iter_ < stop_iter) {
// zero-init the params
for (int i = 0; i < net_->params().size(); ++i) {
shared_ptr<Blob<Dtype> > blob = net_->params()[i];
switch (Caffe::mode()) {
case Caffe::CPU:
caffe_set(blob->count(), static_cast<Dtype>(0),
blob->mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
caffe_gpu_set(blob->count(), static_cast<Dtype>(0),
blob->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
}

if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())) {
TestAll();
}

const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
Dtype loss = net_->ForwardBackward(bottom_vec);
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward(bottom_vec);
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
if (losses.size() < average_loss) {
losses.push_back(loss);
int size = losses.size();
Expand Down Expand Up @@ -462,12 +487,39 @@ void SGDSolver<Dtype>::ApplyUpdate() {
}
ClipGradients();
for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) {
Normalize(param_id);
Regularize(param_id);
ComputeUpdateValue(param_id, rate);
}
this->net_->Update();
}

template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
if (this->param_.iter_size() == 1) { return; }
// Scale gradient to counterbalance accumulation.
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
Expand Down
82 changes: 79 additions & 3 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {

protected:
GradientBasedSolverTest() :
seed_(1701), num_(5), channels_(3), height_(10), width_(10) {}
seed_(1701), num_(4), channels_(3), height_(10), width_(10) {}

shared_ptr<SGDSolver<Dtype> > solver_;
int seed_;
Expand Down Expand Up @@ -56,26 +56,32 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}

void RunLeastSquaresSolver(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum, const int num_iters) {
const Dtype weight_decay, const Dtype momentum, const int num_iters,
const int iter_size = 1) {
ostringstream proto;
proto <<
"max_iter: " << num_iters << " "
"base_lr: " << learning_rate << " "
"lr_policy: 'fixed' "
"iter_size: " << iter_size << " "
"net_param { "
" name: 'TestNetwork' "
" layer { "
" name: 'data' "
" type: 'DummyData' "
" dummy_data_param { "
" num: " << num_ << " "
" num: " << num_ / iter_size << " "
" channels: " << channels_ << " "
" height: " << height_ << " "
" width: " << width_ << " "
" channels: 1 "
" height: 1 "
" width: 1 "
" data_filler { "
" type: 'constant' "
" value: 1.0 "
" } "
" data_filler { "
" type: 'gaussian' "
" std: 1.0 "
" } "
Expand Down Expand Up @@ -270,6 +276,45 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}
}

void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
const Dtype kMomentum, const int kNumIters, const int kIterSize) {
const double kPrecision = 1e-2;
const double kMinPrecision = 1e-7;
// Solve without accumulation and save parameters.
this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
kNumIters);
// Save parameters for comparison.
Net<Dtype>& net = *this->solver_->net();
const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
net.layer_by_name("innerprod")->blobs();
vector<shared_ptr<Blob<Dtype> > > noaccum_params(param_blobs.size());
for (int i = 0; i < param_blobs.size(); ++i) {
noaccum_params[i].reset(new Blob<Dtype>());
noaccum_params[i]->CopyFrom(*param_blobs[i], false, true);
}
// Solve by equivalent accumulation of gradients over divided batches.
this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
kNumIters, kIterSize);
Net<Dtype>& net_accum = *this->solver_->net();
const vector<shared_ptr<Blob<Dtype> > >& accum_params =
net_accum.layer_by_name("innerprod")->blobs();
// Compare accumulated parameters against no accumulation standard.
const int D = this->channels_ * this->height_ * this->width_;
for (int i = 0; i < D; ++i) {
const Dtype expected_param = noaccum_params[0]->cpu_data()[i];
const Dtype accum_param = accum_params[0]->cpu_data()[i];
const Dtype error_margin = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_param), fabs(accum_param)));
EXPECT_NEAR(expected_param, accum_param, error_margin);
}
ASSERT_EQ(1, accum_params[1]->count());
const Dtype expected_bias = noaccum_params[1]->cpu_data()[0];
const Dtype accum_bias = accum_params[1]->cpu_data()[0];
const Dtype error_margin = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_bias), fabs(accum_bias)));
EXPECT_NEAR(expected_bias, accum_bias, error_margin);
}

// Test that the correct update is computed for a regularized least squares
// problem:
//
Expand Down Expand Up @@ -372,6 +417,16 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
}
}

TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.1;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
const int kIterSize = 2;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
Expand Down Expand Up @@ -416,6 +471,16 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
}
}

TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.1;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
const int kIterSize = 2;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

template <typename TypeParam>
class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
Expand Down Expand Up @@ -482,4 +547,15 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
}
}

TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.1;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
const int kIterSize = 2;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

} // namespace caffe

0 comments on commit aeef453

Please sign in to comment.