Adam solver
This commit implements the Adam solver by Kingma et al. for CPU and
GPU. All solver parameters are defined in caffe.proto. It also adds
an example for the MNIST dataset.
PatWie authored and ronghanghu committed Aug 14, 2015
1 parent bb0a90e commit 9d60495
Showing 6 changed files with 302 additions and 28 deletions.
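For reference, the update rule from the cited paper, with gradient g_t, step size \alpha, and decay rates \beta_1 and \beta_2 (exposed below as momentum and momentum2), is:

  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
  v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
  \hat{m}_t = m_t / (1 - \beta_1^t)
  \hat{v}_t = v_t / (1 - \beta_2^t)
  \theta_t = \theta_{t-1} - \alpha \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

The implementation below uses the paper's equivalent "efficient" formulation, which folds both bias-correction terms into a single scalar factor on the step size rather than forming \hat{m}_t and \hat{v}_t explicitly.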
26 changes: 26 additions & 0 deletions examples/mnist/lenet_solver_adam.prototxt
@@ -0,0 +1,26 @@
# The train/test net protocol buffer definition
# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION"
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# All parameters are from the cited paper above
base_lr: 0.001
momentum: 0.9
momentum2: 0.999
# Adam adapts the effective step size for each parameter on its own,
# so the base learning rate is simply kept fixed
lr_policy: "fixed"
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_type: ADAM
# solver mode: CPU or GPU
solver_mode: GPU
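The fields above map onto the paper's symbols as follows: base_lr is the step size \alpha, momentum is \beta_1, momentum2 is \beta_2, and the solver's delta field (left at its default of 1e-8 here) plays the role of the \epsilon stabilizer; 0.001, 0.9, and 0.999 are the defaults suggested in the paper.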
3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_adam.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
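Assuming the standard repository layout, the script is meant to be run from the Caffe root after the tools have been built, since it invokes ./build/tools/caffe with a repository-relative solver path.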
17 changes: 17 additions & 0 deletions include/caffe/solver.hpp
@@ -217,6 +217,21 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
 public:
  explicit AdamSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) { AdamPreSolve(); }
  explicit AdamSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }

 protected:
  void AdamPreSolve();
  virtual void ComputeUpdateValue(int param_id, Dtype rate);

  DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
  SolverParameter_SolverType type = param.solver_type();
@@ -232,6 +247,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
    return new RMSPropSolver<Dtype>(param);
  case SolverParameter_SolverType_ADADELTA:
    return new AdaDeltaSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAM:
    return new AdamSolver<Dtype>(param);
  default:
    LOG(FATAL) << "Unknown SolverType: " << type;
  }
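A rough programmatic counterpart to the prototxt-driven path above (not part of this commit; it assumes the usual protobuf-generated setters for the SolverParameter fields shown in this diff and omits logging and initialization boilerplate):

#include "caffe/solver.hpp"

int main() {
  caffe::SolverParameter param;
  param.set_net("examples/mnist/lenet_train_test.prototxt");
  param.set_base_lr(0.001f);    // alpha in the paper
  param.set_momentum(0.9f);     // beta1
  param.set_momentum2(0.999f);  // beta2
  param.set_lr_policy("fixed");
  param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
  // GetSolver dispatches on solver_type and returns an AdamSolver here.
  caffe::Solver<float>* solver = caffe::GetSolver<float>(param);
  solver->Solve();
  delete solver;
  return 0;
}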
7 changes: 5 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 39 (last added: rms_decay)
// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -216,10 +216,13 @@ message SolverParameter {
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  optional SolverType solver_type = 30 [default = SGD];
  // numerical stability for AdaGrad
  // numerical stability for RMSProp, AdaGrad, AdaDelta, and Adam
  optional float delta = 31 [default = 1e-8];
  // parameters for the Adam solver
  optional float momentum2 = 39 [default = 0.999];

  // RMSProp decay value
  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
105 changes: 105 additions & 0 deletions src/caffe/solver.cpp
@@ -1114,11 +1114,116 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}

template <typename Dtype>
void AdamSolver<Dtype>::AdamPreSolve() {
  // SGDSolver::PreSolve has already allocated one history blob per learnable
  // parameter (used below as the first moment m); append a second set with
  // the same shapes to hold the second moment v.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  for (int i = 0; i < net_params.size(); ++i) {
    const vector<int>& shape = net_params[i]->shape();
    this->history_.push_back(
        shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
  }
}

template <typename Dtype>
void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype local_rate = rate * net_params_lr[param_id];
  const Dtype beta1 = this->param_.momentum();
  const Dtype beta2 = this->param_.momentum2();

  // Aliases for convenience: the first half of history_ holds the first
  // moment estimate m, the second half holds the second moment estimate v,
  // and temp_ serves as scratch space.
  size_t update_history_offset = net_params.size();
  shared_ptr<Blob<Dtype> > val_m = this->history_[param_id];
  shared_ptr<Blob<Dtype> > val_v =
      this->history_[param_id + update_history_offset];
  shared_ptr<Blob<Dtype> > val_t = this->temp_[param_id];

  const int t = this->iter_ + 1;
  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
      (Dtype(1.) - pow(beta1, t));
  const int N = net_params[param_id]->count();
  const Dtype eps_hat = this->param_.delta();
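  // Note: rather than materializing the bias-corrected moments
  //   \hat{m} = m / (1 - \beta_1^t)  and  \hat{v} = v / (1 - \beta_2^t),
  // both corrections are folded into the single 'correction' scalar above;
  // with eps_hat applied to the uncorrected sqrt(v), the step computed below
  // is the \hat{\epsilon} formulation from Section 2 of the paper:
  //   \theta_t = \theta_{t-1} - local_rate * correction * m / (sqrt(v) + eps_hat)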

  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // update m <- \beta_1 m_{t-1} + (1-\beta_1) g_t
    caffe_cpu_axpby(N, Dtype(1)-beta1,
        net_params[param_id]->cpu_diff(), beta1,
        val_m->mutable_cpu_data());

    // update v <- \beta_2 v_{t-1} + (1-\beta_2) g_t^2
    caffe_mul(N,
        net_params[param_id]->cpu_diff(),
        net_params[param_id]->cpu_diff(),
        val_t->mutable_cpu_data());
    caffe_cpu_axpby(N, Dtype(1)-beta2,
        val_t->cpu_data(), beta2,
        val_v->mutable_cpu_data());

    // set update: m / (sqrt(v) + eps_hat), scaled by local_rate*correction
    caffe_powx(N,
        val_v->cpu_data(), Dtype(0.5),
        val_t->mutable_cpu_data());
    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
    caffe_div(N,
        val_m->cpu_data(),
        val_t->cpu_data(),
        val_t->mutable_cpu_data());

    caffe_cpu_scale(N, local_rate*correction,
        val_t->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    // update m <- \beta_1 m_{t-1} + (1-\beta_1) g_t
    caffe_gpu_axpby(N, Dtype(1)-beta1,
        net_params[param_id]->gpu_diff(), beta1,
        val_m->mutable_gpu_data());

    // update v <- \beta_2 v_{t-1} + (1-\beta_2) g_t^2
    caffe_gpu_mul(N,
        net_params[param_id]->gpu_diff(),
        net_params[param_id]->gpu_diff(),
        val_t->mutable_gpu_data());
    caffe_gpu_axpby(N, Dtype(1)-beta2,
        val_t->gpu_data(), beta2,
        val_v->mutable_gpu_data());

    // set update: m / (sqrt(v) + eps_hat), scaled by local_rate*correction
    caffe_gpu_powx(N,
        val_v->gpu_data(), Dtype(0.5),
        val_t->mutable_gpu_data());
    caffe_gpu_add_scalar(N, eps_hat,
        val_t->mutable_gpu_data());
    caffe_gpu_div(N,
        val_m->gpu_data(),
        val_t->gpu_data(),
        val_t->mutable_gpu_data());

    caffe_gpu_scale(N, local_rate*correction,
        val_t->gpu_data(),
        net_params[param_id]->mutable_gpu_diff());
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSPropSolver);
INSTANTIATE_CLASS(AdaDeltaSolver);
INSTANTIATE_CLASS(AdamSolver);

} // namespace caffe
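As with Caffe's other SGD-family solvers, ComputeUpdateValue only writes the scaled step into each parameter's diff; the actual subtraction of the diff from the parameter data is performed afterwards by the shared update path, which is why Adam requires no changes to the solver loop itself.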
