diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
index cc4f0bbb4a7..4e43468a71f 100644
--- a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
+++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
@@ -6,6 +6,7 @@ test_iter: 100
 test_interval: 500
 test_compute_loss: true
 momentum: 0.95
+delta: 1e-8
 display: 100
 max_iter: 65000
 weight_decay: 0.0005
@@ -14,4 +15,3 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
 solver_type: ADADELTA
-delta: 1e-8
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 4b408380119..495cd4f159e 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> {
   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
 
  protected:
-  void PreSolve();
   Dtype GetLearningRate();
   virtual void ApplyUpdate();
   virtual void Normalize(int param_id);
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  virtual void PreSolve();
   virtual void ClipGradients();
   virtual void SnapshotSolverState(const string& model_filename);
   virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
@@ -162,9 +162,9 @@ template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
   explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
   explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
 
  protected:
   virtual void PreSolve();
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index d8749a1b939..34a290ffe3d 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -936,35 +936,21 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::PreSolve() {
-  // Initialize the history
-  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  this->history_.clear();
-  this->update_.clear();
-  this->temp_.clear();
-  for (int i = 0; i < net_params.size(); ++i) {
-    const Blob<Dtype>* net_param = net_params[i].get();
-    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-    this->update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-    this->temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-  }
+  // Add the extra history entries for AdaDelta after those from
+  // SGDSolver::PreSolve
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   for (int i = 0; i < net_params.size(); ++i) {
-    const Blob<Dtype>* net_param = net_params[i].get();
-    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+        shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
   }
 }
 
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::ComputeUpdateValue() {
-  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_weight_decay =
+      this->net_->params_weight_decay();
   Dtype delta = this->param_.delta();
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index db89e285a9f..277aa3a5c8e 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -1060,7 +1060,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.0;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
-  const int kNumIters = 500;
+  const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
@@ -1071,7 +1071,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.0;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
-  const int kNumIters = 500;
+  const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
     this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
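
Note (not part of the patch): the rewritten AdaDeltaSolver<Dtype>::PreSolve above relies on SGDSolver::PreSolve having already allocated one history/update/temp blob per learnable parameter, and only appends a second, shape-matched history entry per parameter for AdaDelta's accumulated-update term. A minimal standalone sketch of that allocation pattern, using plain std:: types and a hypothetical FakeBlob in place of caffe::Blob<Dtype>, might look like:

// Standalone sketch (not Caffe code): mimic the allocation pattern from the
// patch above -- keep the base solver's per-parameter history and append one
// extra, shape-matched entry per parameter for AdaDelta. "FakeBlob" is a
// stand-in for caffe::Blob<Dtype>.
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>

struct FakeBlob {
  explicit FakeBlob(const std::vector<int>& shape)
      : shape(shape),
        data(std::accumulate(shape.begin(), shape.end(), 1,
                             std::multiplies<int>()), 0.0f) {}
  std::vector<int> shape;
  std::vector<float> data;
};

int main() {
  // Pretend these are the shapes of the net's learnable parameters.
  std::vector<std::vector<int> > param_shapes;
  param_shapes.push_back(std::vector<int>(4, 5));   // e.g. a 5x5x5x5 weight
  param_shapes.push_back(std::vector<int>(1, 10));  // e.g. a 10-element bias

  std::vector<std::shared_ptr<FakeBlob> > history;

  // What SGDSolver::PreSolve already provides: one history blob per parameter.
  for (size_t i = 0; i < param_shapes.size(); ++i)
    history.push_back(std::make_shared<FakeBlob>(param_shapes[i]));

  // The AdaDelta-specific step: append a second history entry per parameter,
  // sized directly from the parameter shape (shape-based, not N/C/H/W).
  for (size_t i = 0; i < param_shapes.size(); ++i)
    history.push_back(std::make_shared<FakeBlob>(param_shapes[i]));

  std::cout << "history entries: " << history.size() << std::endl;  // prints 4
  return 0;
}

Under this pattern the AdaDelta override must not clear history_, update_, or temp_, since that would discard the entries the base class already allocated; that is why the clear() calls are removed in the patch.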