
Clean up and modernize AdaDelta code; add learning rate support; add additional test cases
matthiasplappert committed Aug 10, 2015
1 parent 4c58741 commit f2e523e
Showing 5 changed files with 260 additions and 245 deletions.
2 changes: 2 additions & 0 deletions examples/mnist/lenet_adadelta_solver.prototxt
@@ -7,6 +7,8 @@ test_iter: 100
 # Carry out testing every 500 training iterations.
 test_interval: 500
 # The base learning rate, momentum and the weight decay of the network.
+base_lr: 1.0
+lr_policy: "fixed"
 momentum: 0.95
 weight_decay: 0.0005
 # Display every 100 iterations
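A note on the two added lines (my reading, not part of the commit): AdaDelta previously rejected any learning-rate settings (see the removed constructor_sanity_check in solver.hpp below); this commit makes the solver honor them instead, and the example picks the neutral configuration. With lr_policy "fixed" the rate never changes, so at every iteration

    \text{rate}(t) = \text{base\_lr} = 1.0,

and the adaptive AdaDelta step passes through unscaled, matching the previous learning-rate-free behavior.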
2 changes: 2 additions & 0 deletions examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
@@ -5,6 +5,8 @@ test_state: { stage: 'test-on-test' }
 test_iter: 100
 test_interval: 500
 test_compute_loss: true
+base_lr: 1.0
+lr_policy: "fixed"
 momentum: 0.95
 delta: 1e-8
 display: 100
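For reference (a summary added here, not part of the diff; following Zeiler's ADADELTA method as implemented in solver.cpp below): momentum is the decay \rho of the two running averages and delta is the conditioning constant \varepsilon inside the RMS terms, so with \rho = 0.95 and \varepsilon = 10^{-8} the step applied to each weight is

    E[g^2]_t = \rho\, E[g^2]_{t-1} + (1 - \rho)\, g_t^2
    \Delta x_t = \frac{\sqrt{E[\Delta x^2]_{t-1} + \varepsilon}}{\sqrt{E[g^2]_t + \varepsilon}}\, g_t
    E[\Delta x^2]_t = \rho\, E[\Delta x^2]_{t-1} + (1 - \rho)\, \Delta x_t^2

with the commit additionally scaling \Delta x_t by the (here fixed) learning rate before the weight update.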
16 changes: 5 additions & 11 deletions include/caffe/solver.hpp
@@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> {
   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

  protected:
+  void PreSolve();
   Dtype GetLearningRate();
   virtual void ApplyUpdate();
   virtual void Normalize(int param_id);
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
-  virtual void PreSolve();
   virtual void ClipGradients();
   virtual void SnapshotSolverState(const string& model_filename);
   virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
@@ -162,19 +162,13 @@ template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
   explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
   explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }

  protected:
-  virtual void PreSolve();
-  virtual void ComputeUpdateValue();
-  void constructor_sanity_check() {
-    CHECK_EQ(0, this->param_.base_lr())
-        << "Learning rate cannot be used with AdaDelta.";
-    CHECK_EQ("", this->param_.lr_policy())
-        << "Learning rate policy cannot be applied to AdaDelta.";
-  }
+  void AdaDeltaPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);

   DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
 };
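The structural point of this header change: AdaDeltaSolver no longer replaces the whole update loop (and so no longer needs its own weight-decay handling or the base_lr/lr_policy checks); it now only overrides SGDSolver's per-parameter hook. A minimal, self-contained C++ sketch of that pattern, with hypothetical names (BaseSolver, ToyAdaDelta) rather than Caffe's actual classes:

#include <cstdio>

class BaseSolver {
 public:
  virtual ~BaseSolver() {}
  // The base class owns the update loop and the shared steps.
  void ApplyUpdate(int num_params, double rate) {
    for (int id = 0; id < num_params; ++id) {
      Regularize(id);                // shared weight-decay step
      ComputeUpdateValue(id, rate);  // solver-specific step
    }
  }
 protected:
  virtual void Regularize(int param_id) {
    std::printf("regularize param %d\n", param_id);
  }
  virtual void ComputeUpdateValue(int param_id, double rate) = 0;
};

class ToyAdaDelta : public BaseSolver {
 protected:
  void ComputeUpdateValue(int param_id, double rate) override {
    std::printf("adadelta update for param %d at rate %g\n", param_id, rate);
  }
};

int main() {
  ToyAdaDelta solver;
  solver.ApplyUpdate(2, 1.0);  // regularization is not duplicated per solver
  return 0;
}

The practical effect is that normalization, regularization, and gradient clipping run identically for every SGD-family solver, which is what lets the duplicated weight-decay code in solver.cpp below be deleted.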
274 changes: 118 additions & 156 deletions src/caffe/solver.cpp
@@ -935,10 +935,10 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }

 template <typename Dtype>
-void AdaDeltaSolver<Dtype>::PreSolve() {
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
   // Add the extra history entries for AdaDelta after those from
   // SGDSolver::PreSolve
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   for (int i = 0; i < net_params.size(); ++i) {
     const vector<int>& shape = net_params[i]->shape();
     this->history_.push_back(
@@ -947,172 +947,134 @@ void AdaDeltaSolver<Dtype>::PreSolve() {
 }

 template <typename Dtype>
-void AdaDeltaSolver<Dtype>::ComputeUpdateValue() {
-  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  const vector<float>& net_params_weight_decay =
-      this->net_->params_weight_decay();
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype delta = this->param_.delta();
   Dtype momentum = this->param_.momentum();
-  Dtype weight_decay = this->param_.weight_decay();
-  string regularization_type = this->param_.regularization_type();
+  Dtype local_rate = rate * net_params_lr[param_id];
   size_t update_history_offset = net_params.size();
   switch (Caffe::mode()) {
-  case Caffe::CPU:
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_cpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->cpu_data(),
-              this->temp_[param_id]->mutable_cpu_data());
-          caffe_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->cpu_data(),
-              net_params[param_id]->mutable_cpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
+  case Caffe::CPU: {
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
-      // compute square of gradient in update
-      caffe_powx(net_params[param_id]->count(),
-          net_params[param_id]->cpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_cpu_data());
-
-      // update history of gradients
-      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-          this->update_[param_id]->cpu_data(), momentum,
-          this->history_[param_id]->mutable_cpu_data());
-
-      // add delta to history to guard against dividing by zero later
-      caffe_set(net_params[param_id]->count(), delta,
-          this->temp_[param_id]->mutable_cpu_data());
-
-      caffe_add(net_params[param_id]->count(),
-          this->temp_[param_id]->cpu_data(),
-          this->history_[update_history_offset + param_id]->cpu_data(),
-          this->update_[param_id]->mutable_cpu_data());
-
-      caffe_add(net_params[param_id]->count(),
-          this->temp_[param_id]->cpu_data(),
-          this->history_[param_id]->cpu_data(),
-          this->temp_[param_id]->mutable_cpu_data());
-
-      // divide history of updates by history of gradients
-      caffe_div(net_params[param_id]->count(),
-          this->update_[param_id]->cpu_data(),
-          this->temp_[param_id]->cpu_data(),
-          this->update_[param_id]->mutable_cpu_data());
-
-      // jointly compute the RMS of both for update and gradient history
-      caffe_powx(net_params[param_id]->count(),
-          this->update_[param_id]->cpu_data(), Dtype(0.5),
-          this->update_[param_id]->mutable_cpu_data());
-
-      // compute the update
-      caffe_mul(net_params[param_id]->count(),
-          net_params[param_id]->cpu_diff(),
-          this->update_[param_id]->cpu_data(),
-          net_params[param_id]->mutable_cpu_diff());
-
-      // compute square of update
-      caffe_powx(net_params[param_id]->count(),
-          net_params[param_id]->cpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_cpu_data());
-
-      // update history of updates
-      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-          this->update_[param_id]->cpu_data(), momentum,
-          this->history_[update_history_offset + param_id]->mutable_cpu_data());
-    }
+    // update history of gradients
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[param_id]->mutable_cpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[update_history_offset + param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->temp_[param_id]->mutable_cpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_div(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->temp_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // jointly compute the RMS of both for update and gradient history
+    caffe_powx(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // compute the update
+    caffe_mul(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+
+    // compute square of update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of updates
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+    // apply learning rate
+    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->mutable_cpu_diff());
     break;
-  case Caffe::GPU:
+  }
+  case Caffe::GPU: {
 #ifndef CPU_ONLY
-    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
-      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-
-      if (local_decay) {
-        if (regularization_type == "L2") {
-          // add weight decay
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              net_params[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else if (regularization_type == "L1") {
-          caffe_gpu_sign(net_params[param_id]->count(),
-              net_params[param_id]->gpu_data(),
-              this->temp_[param_id]->mutable_gpu_data());
-          caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay,
-              this->temp_[param_id]->gpu_data(),
-              net_params[param_id]->mutable_gpu_diff());
-        } else {
-          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
-        }
-      }
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
-      // compute square of gradient in update
-      caffe_gpu_powx(net_params[param_id]->count(),
-          net_params[param_id]->gpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_gpu_data());
-
-      // update history of gradients
-      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-          this->update_[param_id]->gpu_data(), momentum,
-          this->history_[param_id]->mutable_gpu_data());
-
-      // add delta to history to guard against dividing by zero later
-      caffe_gpu_set(net_params[param_id]->count(), delta,
-          this->temp_[param_id]->mutable_gpu_data());
-
-      caffe_gpu_add(net_params[param_id]->count(),
-          this->temp_[param_id]->gpu_data(),
-          this->history_[update_history_offset + param_id]->gpu_data(),
-          this->update_[param_id]->mutable_gpu_data());
-
-      caffe_gpu_add(net_params[param_id]->count(),
-          this->temp_[param_id]->gpu_data(),
-          this->history_[param_id]->gpu_data(),
-          this->temp_[param_id]->mutable_gpu_data());
-
-      // divide history of updates by history of gradients
-      caffe_gpu_div(net_params[param_id]->count(),
-          this->update_[param_id]->gpu_data(),
-          this->temp_[param_id]->gpu_data(),
-          this->update_[param_id]->mutable_gpu_data());
-
-      // jointly compute the RMS of both for update and gradient history
-      caffe_gpu_powx(net_params[param_id]->count(),
-          this->update_[param_id]->gpu_data(), Dtype(0.5),
-          this->update_[param_id]->mutable_gpu_data());
-
-      // compute the update and copy to net_diff
-      caffe_gpu_mul(net_params[param_id]->count(),
-          net_params[param_id]->gpu_diff(),
-          this->update_[param_id]->gpu_data(),
-          net_params[param_id]->mutable_gpu_diff());
-
-      // compute square of update
-      caffe_gpu_powx(net_params[param_id]->count(),
-          net_params[param_id]->gpu_diff(), Dtype(2),
-          this->update_[param_id]->mutable_gpu_data());
-
-      // update history of updates
-      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-          this->update_[param_id]->gpu_data(), momentum,
-          this->history_[update_history_offset + param_id]->mutable_gpu_data());
-    }
+    // update history of gradients
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[param_id]->mutable_gpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_gpu_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[update_history_offset + param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    caffe_gpu_add(net_params[param_id]->count(),
+        this->temp_[param_id]->gpu_data(),
+        this->history_[param_id]->gpu_data(),
+        this->temp_[param_id]->mutable_gpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_gpu_div(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(),
+        this->temp_[param_id]->gpu_data(),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // jointly compute the RMS of both for update and gradient history
+    caffe_gpu_powx(net_params[param_id]->count(),
+        this->update_[param_id]->gpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // compute the update and copy to net_diff
+    caffe_gpu_mul(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(),
+        this->update_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+
+    // compute square of update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of updates
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_gpu_data());
+
+    // apply learning rate
+    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(),
+        net_params[param_id]->mutable_gpu_diff());
 #else
     NO_GPU;
 #endif
     break;
+  }
   default:
     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
   }
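To make the new per-parameter code path easier to follow, here is a scalar reference sketch (my reconstruction, not Caffe code) of what the rewritten ComputeUpdateValue computes for a single weight; sq_grad_avg plays the role of history_[param_id], sq_update_avg of history_[update_history_offset + param_id], and local_rate of rate * net_params_lr[param_id]:

#include <cmath>
#include <cstdio>

struct AdaDeltaState {
  double sq_grad_avg;    // running average of squared gradients, E[g^2]
  double sq_update_avg;  // running average of squared updates, E[dx^2]
};

double AdaDeltaStep(double grad, AdaDeltaState* s, double momentum,
                    double delta, double local_rate) {
  // update history of gradients: E[g^2] <- m * E[g^2] + (1 - m) * g^2
  s->sq_grad_avg = momentum * s->sq_grad_avg + (1.0 - momentum) * grad * grad;
  // RMS ratio, with delta guarding against division by zero
  double update = grad * std::sqrt((s->sq_update_avg + delta) /
                                   (s->sq_grad_avg + delta));
  // update history of updates: E[dx^2] <- m * E[dx^2] + (1 - m) * dx^2
  s->sq_update_avg =
      momentum * s->sq_update_avg + (1.0 - momentum) * update * update;
  // apply learning rate (new in this commit); the caller subtracts the result
  return local_rate * update;
}

int main() {
  AdaDeltaState s = {0.0, 0.0};
  double w = 1.0;
  for (int iter = 0; iter < 3; ++iter) {
    double grad = 2.0 * w;  // gradient of the toy objective w^2
    w -= AdaDeltaStep(grad, &s, 0.95, 1e-8, 1.0);
    std::printf("iter %d: w = %f\n", iter, w);
  }
  return 0;
}

Note the ordering: the squared-update history is accumulated before the learning-rate scaling, so with a rate of 1.0 the accumulated statistics and the step match the previous learning-rate-free implementation.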
[Diff for the fifth changed file not shown.]