Skip to content

Commit

Permalink
Move regularization into appropriate method
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasplappert committed Aug 9, 2015
1 parent 1c62c83 commit fc35518
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
1 change: 1 addition & 0 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {

protected:
virtual void PreSolve();
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.base_lr())
Expand Down
70 changes: 42 additions & 28 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,19 +873,15 @@ void AdaDeltaSolver<Dtype>::PreSolve() {
}

template <typename Dtype>
void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
void AdaDeltaSolver<Dtype>::Regularize(int param_id) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
Dtype delta = this->param_.delta();
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
size_t update_history_offset = net_params.size();
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
switch (Caffe::mode()) {
case Caffe::CPU: {
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
Expand All @@ -905,7 +901,47 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
this->temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

template <typename Dtype>
void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
Dtype delta = this->param_.delta();
Dtype momentum = this->param_.momentum();
size_t update_history_offset = net_params.size();
switch (Caffe::mode()) {
case Caffe::CPU: {
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
Expand Down Expand Up @@ -960,28 +996,6 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
case Caffe::GPU: {
#ifndef CPU_ONLY
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
this->temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
Expand Down

0 comments on commit fc35518

Please sign in to comment.