Skip to content

Commit

Permalink
Add support for learning rate to AdaDelta
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasplappert committed Aug 9, 2015
1 parent 3c341e6 commit 322a9de
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 20 deletions.
2 changes: 2 additions & 0 deletions examples/mnist/lenet_adadelta_solver.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum, and weight decay of the network.
base_lr: 1.0
lr_policy: "fixed"
momentum: 0.95
weight_decay: 0.0005
# Display every 100 iterations
Expand Down
2 changes: 2 additions & 0 deletions examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
test_compute_loss: true
base_lr: 1.0
lr_policy: "fixed"
momentum: 0.95
delta: 1e-8
display: 100
Expand Down
10 changes: 2 additions & 8 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,20 +161,14 @@ template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
explicit AdaDeltaSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
: SGDSolver<Dtype>(param) { PreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
: SGDSolver<Dtype>(param_file) { PreSolve(); }

protected:
void PreSolve();
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.base_lr())
<< "Learning rate cannot be used with AdaDelta.";
CHECK_EQ("", this->param_.lr_policy())
<< "Learning rate policy cannot be applied to AdaDelta.";
}

DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};
Expand Down
15 changes: 13 additions & 2 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
rate = Dtype(0.);
//LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
return rate;
}
Expand Down Expand Up @@ -937,8 +936,10 @@ void AdaDeltaSolver<Dtype>::Regularize(int param_id) {
template <typename Dtype>
void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype delta = this->param_.delta();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
size_t update_history_offset = net_params.size();
switch (Caffe::mode()) {
case Caffe::CPU: {
Expand Down Expand Up @@ -992,6 +993,11 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
this->update_[param_id]->cpu_data(), momentum,
this->history_[update_history_offset + param_id]->mutable_cpu_data());

// apply learning rate
caffe_cpu_scale(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
Expand Down Expand Up @@ -1046,6 +1052,11 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
this->update_[param_id]->gpu_data(), momentum,
this->history_[update_history_offset + param_id]->mutable_gpu_data());

// apply learning rate
caffe_gpu_scale(net_params[param_id]->count(), local_rate,
net_params[param_id]->gpu_diff(),
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
Expand Down
18 changes: 8 additions & 10 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
proto <<
"snapshot_after_train: " << snapshot << " "
"max_iter: " << num_iters << " "
"base_lr: " << learning_rate << " "
"lr_policy: 'fixed' "
"iter_size: " << iter_size << " "
"net_param { "
" name: 'TestNetwork' "
Expand Down Expand Up @@ -167,10 +169,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
" bottom: 'targets' "
" } "
"} ";
if (learning_rate != 0) {
proto << "base_lr: " << learning_rate << " ";
proto << "lr_policy: 'fixed' ";
}
if (weight_decay != 0) {
proto << "weight_decay: " << weight_decay << " ";
}
Expand Down Expand Up @@ -919,21 +917,21 @@ TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);

TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
this->TestLeastSquaresUpdate(kLearningRate);
}

TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.95;
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
}

TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.5;
const int kNumIters = 1;
Expand All @@ -944,7 +942,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {

TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.95;
const int kNumIters = 1;
Expand All @@ -955,7 +953,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {

TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.95;
const int kNumIters = 4;
Expand All @@ -966,7 +964,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {

TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.0;
const Dtype kLearningRate = 1.0;
const Dtype kWeightDecay = 0.1;
const Dtype kMomentum = 0.95;
const int kNumIters = 4;
Expand Down

0 comments on commit 322a9de

Please sign in to comment.