diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt
index b77b451d56a..776d1e06139 100644
--- a/examples/mnist/lenet_adadelta_solver.prototxt
+++ b/examples/mnist/lenet_adadelta_solver.prototxt
@@ -7,6 +7,8 @@ test_iter: 100
 # Carry out testing every 500 training iterations.
 test_interval: 500
 # The base learning rate, momentum and the weight decay of the network.
+base_lr: 1.0
+lr_policy: "fixed"
 momentum: 0.95
 weight_decay: 0.0005
 # Display every 100 iterations
diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
index 4e43468a71f..065647df31b 100644
--- a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
+++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
@@ -5,6 +5,8 @@ test_state: { stage: 'test-on-test' }
 test_iter: 100
 test_interval: 500
 test_compute_loss: true
+base_lr: 1.0
+lr_policy: "fixed"
 momentum: 0.95
 delta: 1e-8
 display: 100
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 8ca0588e821..3f204ab00d5 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -161,20 +161,14 @@ template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
   explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param) { PreSolve(); }
   explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param_file) { PreSolve(); }
 
  protected:
   void PreSolve();
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
-  void constructor_sanity_check() {
-    CHECK_EQ(0, this->param_.base_lr())
-        << "Learning rate cannot be used with AdaDelta.";
-    CHECK_EQ("", this->param_.lr_policy())
-        << "Learning rate policy cannot be applied to AdaDelta.";
-  }
 
   DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
 };
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index a0c78a52009..2247718ba79 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -437,8 +437,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
         (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
           Dtype(this->param_.stepsize())))));
   } else {
-    rate = Dtype(0.);
-    //LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
   return rate;
 }
@@ -937,8 +936,10 @@ void AdaDeltaSolver<Dtype>::Regularize(int param_id) {
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype delta = this->param_.delta();
   Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
   size_t update_history_offset = net_params.size();
   switch (Caffe::mode()) {
   case Caffe::CPU: {
@@ -992,6 +993,11 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
     caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
         this->update_[param_id]->cpu_data(), momentum,
         this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+    // apply learning rate
+    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->mutable_cpu_diff());
     break;
   }
   case Caffe::GPU: {
@@ -1046,6 +1052,11 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
     caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
         this->update_[param_id]->gpu_data(), momentum,
         this->history_[update_history_offset + param_id]->mutable_gpu_data());
+
+    // apply learning rate
+    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->mutable_gpu_diff());
 #else
     NO_GPU;
 #endif
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 754e8a64eed..b654295ccf9 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -66,6 +66,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     proto <<
        "snapshot_after_train: " << snapshot << " "
        "max_iter: " << num_iters << " "
+       "base_lr: " << learning_rate << " "
+       "lr_policy: 'fixed' "
        "iter_size: " << iter_size << " "
        "net_param { "
        "  name: 'TestNetwork' "
@@ -167,10 +169,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
        "    bottom: 'targets' "
        "  } "
        "} ";
-    if (learning_rate != 0) {
-      proto << "base_lr: " << learning_rate << " ";
-      proto << "lr_policy: 'fixed' ";
-    }
     if (weight_decay != 0) {
       proto << "weight_decay: " << weight_decay << " ";
     }
@@ -919,13 +917,13 @@ TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   this->TestLeastSquaresUpdate(kLearningRate);
 }
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   const Dtype kWeightDecay = 0.5;
   const Dtype kMomentum = 0.95;
   this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
@@ -933,7 +931,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.5;
   const int kNumIters = 1;
@@ -944,7 +942,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 1;
@@ -955,7 +953,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;
@@ -966,7 +964,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
 
 TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   typedef typename TypeParam::Dtype Dtype;
-  const Dtype kLearningRate = 0.0;
+  const Dtype kLearningRate = 1.0;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
   const int kNumIters = 4;