Commit 1c62c83
Merge branch 'master' of https://github.com/BVLC/caffe into adadelta
matthiasplappert committed Aug 9, 2015
2 parents a099482 + 28a579e commit 1c62c83
Showing 6 changed files with 275 additions and 11 deletions.
27 changes: 27 additions & 0 deletions examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,27 @@
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum, and weight decay of the network.
base_lr: 0.01
momentum: 0.0
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: RMSPROP
rms_decay: 0.98
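
For reference, the "inv" policy selected above decays the effective learning rate as base_lr * (1 + gamma * iter) ^ (-power); see the lr_policy comment in src/caffe/proto/caffe.proto further down. A minimal standalone sketch of that schedule with the values from this solver file (inv_policy_rate is an illustrative helper, not a Caffe function):

#include <cmath>
#include <cstdio>

// "inv" learning rate policy: rate = base_lr * (1 + gamma * iter) ^ (-power)
double inv_policy_rate(double base_lr, double gamma, double power, int iter) {
  return base_lr * std::pow(1.0 + gamma * iter, -power);
}

int main() {
  // Values from lenet_solver_rmsprop.prototxt.
  const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
  for (int iter = 0; iter <= 10000; iter += 2500) {
    std::printf("iter %5d: lr = %.6f\n", iter,
                inv_policy_rate(base_lr, gamma, power, iter));
  }
  return 0;
}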
3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
24 changes: 24 additions & 0 deletions include/caffe/solver.hpp
@@ -135,6 +135,28 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};

template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
public:
explicit RMSPropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
}

DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
@@ -169,6 +191,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_ADADELTA:
return new AdaDeltaSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSPropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
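The constructor_sanity_check() added above rejects any nonzero momentum and any rms_decay outside [0, 1) before training starts. A standalone sketch of the same guard pattern, where plain asserts stand in for glog's CHECK_EQ/CHECK_GE/CHECK_LT and RMSPropConfig is a hypothetical stand-in for the relevant SolverParameter fields:

#include <cassert>
#include <cstdio>

// Hypothetical stand-in for the momentum/rms_decay fields of SolverParameter.
struct RMSPropConfig {
  double momentum;
  double rms_decay;

  // Mirrors RMSPropSolver::constructor_sanity_check().
  void sanity_check() const {
    assert(momentum == 0.0);  // "Momentum cannot be used with RMSProp."
    assert(rms_decay >= 0.0 && rms_decay < 1.0);  // rms_decay must lie in [0, 1)
  }
};

int main() {
  RMSPropConfig config{0.0, 0.98};  // values from lenet_solver_rmsprop.prototxt
  config.sanity_check();            // passes; nonzero momentum would abort
  std::printf("config accepted: rms_decay = %.2f\n", config.rms_decay);
  return 0;
}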
27 changes: 24 additions & 3 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 39 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];

// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / stepsize))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (-power)
// - multistep: similar to step, but allows non-uniform steps defined by
//   stepvalue
// - poly: the effective learning rate follows a polynomial decay, reaching
//   zero by max_iter. return base_lr * (1 - iter/max_iter) ^ power
// - sigmoid: the effective learning rate follows a sigmoid decay.
//   return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, stepsize, stepvalue, and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
@@ -198,12 +214,17 @@ message SolverParameter {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
optional float delta = 31 [default = 1e-8];

// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38;

// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];
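The rms_decay comment above defines MeanSquare(t) as an exponential moving average of squared gradients. A short worked example of that recursion, assuming rms_decay = 0.98 as in lenet_solver_rmsprop.prototxt and a made-up stream of scalar gradients:

#include <cstdio>

int main() {
  // MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * grad(t)^2
  const double rms_decay = 0.98;
  const double grads[] = {0.5, -0.3, 0.4, -0.2};  // made-up gradient samples
  double mean_square = 0.0;  // history buffer starts at zero
  for (int t = 0; t < 4; ++t) {
    mean_square = rms_decay * mean_square +
                  (1.0 - rms_decay) * grads[t] * grads[t];
    std::printf("t=%d  MeanSquare=%.6f\n", t, mean_square);
  }
  return 0;
}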
76 changes: 76 additions & 0 deletions src/caffe/solver.cpp
@@ -1042,10 +1042,86 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}

template <typename Dtype>
void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();

// get the learning rate
Dtype delta = this->param_.delta();
Dtype rms_decay = this->param_.rms_decay();
Dtype local_rate = rate * net_params_lr[param_id];

switch (Caffe::mode()) {
case Caffe::CPU:
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// update history
caffe_cpu_axpby(net_params[param_id]->count(),
Dtype(1 - rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]->mutable_cpu_data());

// prepare update
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());

caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());

caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// update history
caffe_gpu_axpby(net_params[param_id]->count(),
Dtype(1 - rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]->mutable_gpu_data());

// prepare update
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());

caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());

caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(AdaDeltaSolver);
INSTANTIATE_CLASS(RMSPropSolver);

} // namespace caffe
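
Element by element, the BLAS sequence in RMSPropSolver::ComputeUpdateValue (square the gradient, blend it into the history, take the square root, add delta, divide the gradient by the result, scale by the local rate) reduces to a single scalar formula. A minimal per-element reference, assuming history holds MeanSquare(t-1) and diff holds the raw gradient (rmsprop_step is illustrative, not Caffe code):

#include <cmath>
#include <cstdio>

// One RMSProp step for a single parameter, mirroring the element-wise effect
// of the caffe_* calls above:
//   history = rms_decay * history + (1 - rms_decay) * diff^2
//   update  = local_rate * diff / (sqrt(history) + delta)
double rmsprop_step(double diff, double* history,
                    double rms_decay, double delta, double local_rate) {
  *history = rms_decay * (*history) + (1.0 - rms_decay) * diff * diff;
  return local_rate * diff / (std::sqrt(*history) + delta);
}

int main() {
  double history = 0.0;
  const double rms_decay = 0.98, delta = 1e-8, local_rate = 0.01;
  const double diff = 0.5;  // made-up gradient, held constant for illustration
  for (int t = 0; t < 3; ++t) {
    double update = rmsprop_step(diff, &history, rms_decay, delta, local_rate);
    std::printf("t=%d  history=%.6f  update=%.6f\n", t, history, update);
  }
  return 0;
}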
129 changes: 121 additions & 8 deletions src/caffe/test/test_gradient_based_solver.cpp
@@ -52,8 +52,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
}
InitSolver(param);
delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
solver_type() == SolverParameter_SolverType_ADADELTA ||
solver_type() == SolverParameter_SolverType_RMSPROP) ?
param.delta() : 0;
}

@@ -307,6 +308,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// momentum * update_history_value + (1 - momentum) * (update_value);
break;
}
case SolverParameter_SolverType_RMSPROP: {
const Dtype rms_decay = 0.95;
update_value /= std::sqrt(rms_decay * history_value
+ grad * grad * (1 - rms_decay)) + delta_;
}
break;
default:
LOG(FATAL) << "Unknown solver type: " << solver_type();
}
@@ -435,11 +442,11 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// Compute the (K+1)th update using the analytic least squares gradient.
vector<shared_ptr<Blob<Dtype> > > updated_params;
ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
&updated_params);

// Reinitialize the solver and run K+1 solver iterations.
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
iter_to_check + 1);

// Check that the solver's solution matches ours.
CheckLeastSquaresUpdate(updated_params);
@@ -453,7 +460,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
bool snapshot = false;
const int kIterSize = 1;
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
total_num_iters, kIterSize, snapshot);

// Save the resulting param values.
vector<shared_ptr<Blob<Dtype> > > param_copies;
@@ -487,8 +494,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {

// Reinitialize the solver and run for num_iters more iterations.
snapshot = false;
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
total_num_iters, kIterSize, snapshot, snapshot_name.c_str());

// Check that params now match.
const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
@@ -699,7 +706,7 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
}

TYPED_TEST(AdaGradSolverTest,
TestAdaGradLeastSquaresUpdateWithEverythingShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
@@ -968,4 +975,110 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
}
}

template <typename TypeParam>
class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;

protected:
virtual void InitSolver(const SolverParameter& param) {
const Dtype rms_decay = 0.95;
SolverParameter new_param = param;
new_param.set_rms_decay(rms_decay);
this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
}
virtual SolverParameter_SolverType solver_type() {
return SolverParameter_SolverType_RMSPROP;
}
};

TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);

TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.5;
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
}

TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(RMSPropSolverTest,
TestRMSPropLeastSquaresUpdateWithEverythingShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
const int kIterSize = 2;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.0;
const int kNumIters = 4;
const int kIterSize = 2;
this->share_ = true;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
const int kNumIters = 4;
for (int i = 1; i <= kNumIters; ++i) {
this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 1; i <= kNumIters; ++i) {
this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
}
}

} // namespace caffe
