Skip to content

Commit

Permalink
Updated AdaDelta for modern Caffe; reduced iterations on multi-iter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinbache authored and matthiasplappert committed Aug 10, 2015
1 parent 1ce3380 commit 4c58741
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ test_iter: 100
test_interval: 500
test_compute_loss: true
momentum: 0.95
delta: 1e-8
display: 100
max_iter: 65000
weight_decay: 0.0005
Expand All @@ -14,4 +15,3 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: ADADELTA
delta: 1e-8
6 changes: 3 additions & 3 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> {
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

protected:
void PreSolve();
Dtype GetLearningRate();
virtual void ApplyUpdate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void PreSolve();
virtual void ClipGradients();
virtual void SnapshotSolverState(const string& model_filename);
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
Expand Down Expand Up @@ -162,9 +162,9 @@ template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
explicit AdaDeltaSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
: SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
: SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }

protected:
virtual void PreSolve();
Expand Down
32 changes: 9 additions & 23 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -936,35 +936,21 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {

/**
 * @brief Appends AdaDelta's extra history blobs to this->history_.
 *
 * One blob is pushed per entry of this->net_->params(), each constructed with
 * the shape of the corresponding parameter blob. The existing contents of
 * history_ are deliberately left untouched (no clear()), so the new blobs
 * land after whatever is already stored there.
 *
 * NOTE(review): this relies on history_ already holding one entry per
 * parameter from SGDSolver::PreSolve before this method runs -- confirm
 * against the constructor call order in solver.hpp.
 */
template <typename Dtype>
void AdaDeltaSolver<Dtype>::PreSolve() {
  // Add the extra history entries for AdaDelta after those from
  // SGDSolver::PreSolve
  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
  for (int i = 0; i < net_params.size(); ++i) {
    // Build the new blob from the parameter's shape vector rather than from
    // the four legacy num/channels/height/width accessors.
    const vector<int>& shape = net_params[i]->shape();
    this->history_.push_back(
        shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
  }
}

template <typename Dtype>
void AdaDeltaSolver<Dtype>::ComputeUpdateValue() {
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
Dtype delta = this->param_.delta();
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
const Dtype kLearningRate = 0.0;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.95;
const int kNumIters = 500;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
Expand All @@ -1071,7 +1071,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
const Dtype kLearningRate = 0.0;
const Dtype kWeightDecay = 0.1;
const Dtype kMomentum = 0.95;
const int kNumIters = 500;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
Expand Down

0 comments on commit 4c58741

Please sign in to comment.