diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 3f204ab00d5..e2337e575bd 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -161,13 +161,12 @@ template class AdaDeltaSolver : public SGDSolver { public: explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { PreSolve(); } + : SGDSolver(param) { AdaDeltaPreSolve(); } explicit AdaDeltaSolver(const string& param_file) - : SGDSolver(param_file) { PreSolve(); } + : SGDSolver(param_file) { AdaDeltaPreSolve(); } protected: - void PreSolve(); - virtual void Regularize(int param_id); + void AdaDeltaPreSolve(); virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 2247718ba79..bee0cd9749d 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -860,10 +860,10 @@ void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { } template -void AdaDeltaSolver::PreSolve() { +void AdaDeltaSolver::AdaDeltaPreSolve() { // Add the extra history entries for AdaDelta after those from // SGDSolver::PreSolve - const vector > >& net_params = this->net_->params(); + const vector*>& net_params = this->net_->learnable_params(); for (int i = 0; i < net_params.size(); ++i) { const vector& shape = net_params[i]->shape(); this->history_.push_back( @@ -871,71 +871,9 @@ void AdaDeltaSolver::PreSolve() { } } -template -void AdaDeltaSolver::Regularize(int param_id) { - const vector > >& net_params = this->net_->params(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - template void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector > >& net_params = this->net_->params(); + const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); Dtype momentum = this->param_.momentum(); diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 63d1a08ea14..d65e82c0b85 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -973,6 +973,18 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { } } +TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverythingShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + this->share_ = true; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { typedef typename TypeParam::Dtype Dtype; const Dtype kLearningRate = 1.0; @@ -984,6 +996,41 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { kIterSize); } +TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + const int kIterSize = 2; + this->share_ = true; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.95; + const int kNumIters = 4; + this->share_ = true; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + template class RMSPropSolverTest : public GradientBasedSolverTest { typedef typename TypeParam::Dtype Dtype;