diff --git a/examples/mnist/lenet_solver_rmsprop.prototxt b/examples/mnist/lenet_solver_rmsprop.prototxt
new file mode 100644
index 00000000000..74dadc51069
--- /dev/null
+++ b/examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,27 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.0
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_rmsprop"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: RMSPROP
+rms_decay: 0.98
diff --git a/examples/mnist/train_lenet_rmsprop.sh b/examples/mnist/train_lenet_rmsprop.sh
new file mode 100755
index 00000000000..621cab238bf
--- /dev/null
+++ b/examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 84a45a8cdab..1cd7b7bbd4c 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -135,6 +135,28 @@ class AdaGradSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
 };
 
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+  explicit RMSPropSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit RMSPropSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with RMSProp.";
+    CHECK_GE(this->param_.rms_decay(), 0)
+        << "rms_decay should lie between 0 and 1.";
+    CHECK_LT(this->param_.rms_decay(), 1)
+        << "rms_decay should lie between 0 and 1.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
 template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
@@ -169,6 +191,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
     return new AdaGradSolver<Dtype>(param);
   case SolverParameter_SolverType_ADADELTA:
     return new AdaDeltaSolver<Dtype>(param);
+  case SolverParameter_SolverType_RMSPROP:
+    return new RMSPropSolver<Dtype>(param);
   default:
     LOG(FATAL) << "Unknown SolverType: " << type;
   }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index e2073a2304c..7cfcaa8bac7 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 38 (last added: snapshot_format)
+// SolverParameter next available ID: 39 (last added: rms_decay)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
   optional int32 max_iter = 7; // the maximum number of iterations
   // accumulate gradients over `iter_size` x `batch_size` instances
   optional int32 iter_size = 36 [default = 1];
-  optional string lr_policy = 8; // The learning rate decay policy.
+
+  // The learning rate decay policy. The currently implemented learning rate
+  // policies are as follows:
+  //    - fixed: always return base_lr.
+  //    - step: return base_lr * gamma ^ (floor(iter / step))
+  //    - exp: return base_lr * gamma ^ iter
+  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+  //    - multistep: similar to step but it allows non-uniform steps defined by
+  //      stepvalue
+  //    - poly: the effective learning rate follows a polynomial decay, to be
+  //      zero by max_iter: return base_lr * (1 - iter/max_iter) ^ power
+  //    - sigmoid: the effective learning rate follows a sigmoid decay:
+  //      return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
+  //
+  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
+  // in the solver parameter protocol buffer, and iter is the current iteration.
+  optional string lr_policy = 8;
   optional float gamma = 9; // The parameter to compute the learning rate.
   optional float power = 10; // The parameter to compute the learning rate.
   optional float momentum = 11; // The momentum value.
@@ -198,12 +214,17 @@ message SolverParameter {
     SGD = 0;
     NESTEROV = 1;
     ADAGRAD = 2;
-    ADADELTA = 3;
+    RMSPROP = 3;
+    ADADELTA = 4;
   }
   optional SolverType solver_type = 30 [default = SGD];
   // numerical stability for AdaGrad
   optional float delta = 31 [default = 1e-8];
+  // RMSProp decay value
+  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
+  optional float rms_decay = 38;
+
   // If true, print information about the state of the net that may help with
   // debugging learning problems.
   optional bool debug_info = 23 [default = false];
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 2326693454d..72ff574ccaa 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -1042,10 +1042,86 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
 }
 
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+
+  // get the learning rate
+  Dtype delta = this->param_.delta();
+  Dtype rms_decay = this->param_.rms_decay();
+  Dtype local_rate = rate * net_params_lr[param_id];
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history
+    caffe_cpu_axpby(net_params[param_id]->count(),
+        Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
+        rms_decay, this->history_[param_id]->mutable_cpu_data());
+
+    // prepare update
+    caffe_powx(net_params[param_id]->count(),
+        this->history_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_cpu_data());
+
+    caffe_div(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // scale and copy
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->cpu_data(), Dtype(0),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history
+    caffe_gpu_axpby(net_params[param_id]->count(),
+        Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
+        rms_decay, this->history_[param_id]->mutable_gpu_data());
+
+    // prepare update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->history_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add_scalar(net_params[param_id]->count(),
+        delta, this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_div(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        this->update_[param_id]->gpu_data(), Dtype(0),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
 INSTANTIATE_CLASS(AdaDeltaSolver);
+INSTANTIATE_CLASS(RMSPropSolver);
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 1a91c41251c..754e8a64eed 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -52,8 +52,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
     }
     InitSolver(param);
-    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD)
-        || (solver_type() == SolverParameter_SolverType_ADADELTA) ?
+    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
+        solver_type() == SolverParameter_SolverType_ADADELTA ||
+        solver_type() == SolverParameter_SolverType_RMSPROP) ?
         param.delta() : 0;
   }
 
@@ -307,6 +308,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
           momentum * update_history_value + (1 - momentum) * (update_value);
       break;
     }
+    case SolverParameter_SolverType_RMSPROP: {
+      const Dtype rms_decay = 0.95;
+      update_value /= std::sqrt(rms_decay * history_value +
+          grad * grad * (1 - rms_decay)) + delta_;
+    }
+      break;
     default:
       LOG(FATAL) << "Unknown solver type: " << solver_type();
     }
@@ -435,11 +442,11 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Compute the (K+1)th update using the analytic least squares gradient.
     vector<shared_ptr<Blob<Dtype> > > updated_params;
     ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
-                              &updated_params);
+        &updated_params);
 
     // Reinitialize the solver and run K+1 solver iterations.
     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-                          iter_to_check + 1);
+        iter_to_check + 1);
 
     // Check that the solver's solution matches ours.
     CheckLeastSquaresUpdate(updated_params);
@@ -453,7 +460,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     bool snapshot = false;
     const int kIterSize = 1;
     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-                          total_num_iters, kIterSize, snapshot);
+        total_num_iters, kIterSize, snapshot);
 
     // Save the resulting param values.
     vector<shared_ptr<Blob<Dtype> > > param_copies;
@@ -487,8 +494,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
     // Reinitialize the solver and run for num_iters more iterations.
     snapshot = false;
-    RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+        total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
 
     // Check that params now match.
     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
@@ -699,7 +706,7 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
 }
 
 TYPED_TEST(AdaGradSolverTest,
-           TestAdaGradLeastSquaresUpdateWithEverythingShare) {
+    TestAdaGradLeastSquaresUpdateWithEverythingShare) {
   typedef typename TypeParam::Dtype Dtype;
   const Dtype kLearningRate = 0.01;
   const Dtype kWeightDecay = 0.5;
@@ -968,4 +975,110 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   }
 }
 
+template <typename TypeParam>
+class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    const Dtype rms_decay = 0.95;
+    SolverParameter new_param = param;
+    new_param.set_rms_decay(rms_decay);
+    this->solver_.reset(new RMSPropSolver<Dtype>(new_param));
+  }
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_RMSPROP;
+  }
+};
+
+TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 4;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest,
+    TestRMSPropLeastSquaresUpdateWithEverythingShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 4;
+  const int kIterSize = 2;
+  this->share_ = true;
+  this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
+      kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const int kNumIters = 4;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0;
+  const int kNumIters = 4;
+  this->share_ = true;
+  for (int i = 1; i <= kNumIters; ++i) {
+    this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 }  // namespace caffe
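
A few notes on the math in this patch follow. First, the per-parameter update implemented by RMSPropSolver<Dtype>::ComputeUpdateValue above (square the gradient into update_, fold it into history_ with axpby, take the square root, add delta, divide the raw gradient by the result, then scale by the local rate) reduces to a simple element-wise recurrence. Below is a minimal standalone C++ sketch of that recurrence, not Caffe code: plain arrays stand in for Blob buffers and the caffe_* BLAS wrappers, and the name rmsprop_update is illustrative only.

#include <cmath>
#include <cstddef>

// Element-wise RMSProp update, mirroring ComputeUpdateValue:
//   history = rms_decay * history + (1 - rms_decay) * grad^2
//   diff    = local_rate * grad / (sqrt(history) + delta)
void rmsprop_update(std::size_t count, float* diff, float* history,
                    float local_rate, float rms_decay, float delta) {
  for (std::size_t i = 0; i < count; ++i) {
    const float grad = diff[i];
    // Decaying mean of squared gradients, as in the caffe.proto comment:
    // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
    history[i] = rms_decay * history[i] + (1.0f - rms_decay) * grad * grad;
    // Normalize the gradient by its running RMS magnitude; delta (the
    // existing numerical-stability field in SolverParameter) guards the
    // division while the history is still near zero.
    diff[i] = local_rate * grad / (std::sqrt(history[i]) + delta);
  }
}

The scaled gradient is the entire update here; the solver's constructor enforces that by rejecting any nonzero momentum via CHECK_EQ(0, this->param_.momentum()).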
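
Second, the new solver prototxt pairs RMSProp with the "inv" learning rate policy, which per the comment block added to caffe.proto computes base_lr * (1 + gamma * iter) ^ (-power). A throwaway sketch (hypothetical main, constants copied from lenet_solver_rmsprop.prototxt) that prints the decay schedule:

#include <cmath>
#include <cstdio>

int main() {
  // Values from examples/mnist/lenet_solver_rmsprop.prototxt.
  const double base_lr = 0.01, gamma = 0.0001, power = 0.75;
  for (int iter = 0; iter <= 10000; iter += 2500) {
    // "inv" policy: rate = base_lr * (1 + gamma * iter) ^ (-power)
    const double rate = base_lr * std::pow(1.0 + gamma * iter, -power);
    std::printf("iter %5d  rate %.6f\n", iter, rate);
  }
  return 0;
}

With these settings the rate decays smoothly from 0.01 at iter 0 to base_lr * 2^-0.75, roughly 0.0059, at max_iter.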
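
Finally, the RMSPROP branch added to ComputeLeastSquaresUpdate in the tests builds its denominator from the pre-update history, sqrt(rms_decay * history_value + grad * grad * (1 - rms_decay)) + delta_, whereas the solver first folds the squared gradient into history_ and then divides by sqrt(history) + delta. The two forms compute the same quantity; a quick sketch with illustrative numbers that checks the equivalence:

#include <cassert>
#include <cmath>

int main() {
  const double rms_decay = 0.95, delta = 1e-8;
  const double grad = 0.3, history = 0.02, lr = 0.01;

  // Solver order: update the history first, then divide.
  const double h_new = rms_decay * history + (1 - rms_decay) * grad * grad;
  const double solver_diff = lr * grad / (std::sqrt(h_new) + delta);

  // Test order: form the same denominator inline from the old history.
  const double test_diff = lr * grad /
      (std::sqrt(rms_decay * history + grad * grad * (1 - rms_decay)) + delta);

  assert(std::fabs(solver_diff - test_diff) < 1e-12);
  return 0;
}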