Speeding up the GPU solvers
philkr committed Jan 6, 2016
1 parent b086cc3 commit 6d09ca2
Showing 12 changed files with 223 additions and 161 deletions.
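Every file below follows the same pattern: each solver's GPU branch used to chain generic caffe_gpu_* math calls (powx, axpby, add, div, scale, ...), each of which is a separate kernel launch and a full pass over the parameter buffers, several of them routed through scratch blobs (update_, temp_). The commit collapses each chain into a single solver-specific CUDA kernel that reads and writes every element exactly once, and forward-declares the new *_update_gpu function directly in the .cpp file instead of a header, keeping all CUDA code in the new .cu files. As a minimal standalone sketch of the fusion idea (not code from this commit; the plain grid-stride loop and names here stand in for Caffe's CUDA_KERNEL_LOOP/CAFFE_GET_BLOCKS machinery):

#include <cuda_runtime.h>

// Fused h <- momentum * h + (1 - momentum) * g*g. Unfused, this takes two
// launches (square into a temp buffer, then axpby) and three full memory
// passes; the fused kernel touches g and h once each and needs no temp.
__global__ void fused_history_update(int n, const float* g, float* h,
    float momentum) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    float gi = g[i];
    h[i] = momentum * h[i] + (1.0f - momentum) * gi * gi;
  }
}

int main() {
  const int n = 1 << 20;
  float *g, *h;
  cudaMalloc(&g, n * sizeof(float));
  cudaMalloc(&h, n * sizeof(float));
  cudaMemset(g, 0, n * sizeof(float));
  cudaMemset(h, 0, n * sizeof(float));
  // one launch instead of two, and no temporary allocation
  fused_history_update<<<(n + 255) / 256, 256>>>(n, g, h, 0.9f);
  cudaDeviceSynchronize();
  cudaFree(g);
  cudaFree(h);
  return 0;
}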
66 changes: 11 additions & 55 deletions src/caffe/solvers/adadelta_solver.cpp
@@ -16,6 +16,12 @@ void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
   }
 }
 
+#ifndef CPU_ONLY
+template <typename Dtype>
+void adadelta_update_gpu(int N, Dtype* g, Dtype* h, Dtype* h2, Dtype momentum,
+    Dtype delta, Dtype local_rate);
+#endif
+
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
@@ -85,61 +91,11 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
   case Caffe::GPU: {
 #ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history of gradients
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->gpu_data(), momentum,
-        this->history_[param_id]->mutable_gpu_data());
-
-    // add delta to history to guard against dividing by zero later
-    caffe_gpu_set(net_params[param_id]->count(), delta,
-        this->temp_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->temp_[param_id]->gpu_data(),
-        this->history_[update_history_offset + param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->temp_[param_id]->gpu_data(),
-        this->history_[param_id]->gpu_data(),
-        this->temp_[param_id]->mutable_gpu_data());
-
-    // divide history of updates by history of gradients
-    caffe_gpu_div(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        this->temp_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // jointly compute the RMS of both for update and gradient history
-    caffe_gpu_powx(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // compute the update and copy to net_diff
-    caffe_gpu_mul(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(),
-        this->update_[param_id]->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-
-    // compute square of update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history of updates
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
-        this->update_[param_id]->gpu_data(), momentum,
-        this->history_[update_history_offset + param_id]->mutable_gpu_data());
-
-    // apply learning rate
-    caffe_gpu_scale(net_params[param_id]->count(), local_rate,
-        net_params[param_id]->gpu_diff(),
-        net_params[param_id]->mutable_gpu_diff());
+    adadelta_update_gpu(net_params[param_id]->count(),
+        net_params[param_id]->mutable_gpu_diff(),
+        this->history_[param_id]->mutable_gpu_data(),
+        this->history_[update_history_offset + param_id]->mutable_gpu_data(),
+        momentum, delta, local_rate);
 #else
     NO_GPU;
 #endif
30 changes: 30 additions & 0 deletions src/caffe/solvers/adadelta_solver.cu
@@ -0,0 +1,30 @@
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void AdaDeltaUpdate(int N, Dtype* g, Dtype* h, Dtype* h2,
+    Dtype momentum, Dtype delta, Dtype local_rate) {
+  CUDA_KERNEL_LOOP(i, N) {
+    float gi = g[i];
+    float hi = h[i] = momentum * h[i] + (1-momentum) * gi * gi;
+    gi = gi * sqrt((h2[i] + delta) / (hi + delta));
+    h2[i] = momentum * h2[i] + (1-momentum) * gi * gi;
+    g[i] = local_rate * gi;
+  }
+}
+template <typename Dtype>
+void adadelta_update_gpu(int N, Dtype* g, Dtype* h, Dtype* h2, Dtype momentum,
+    Dtype delta, Dtype local_rate) {
+  AdaDeltaUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, g, h, h2, momentum, delta, local_rate);
+  CUDA_POST_KERNEL_CHECK;
+}
+template void adadelta_update_gpu<float>(int, float*, float*, float*,
+    float, float, float);
+template void adadelta_update_gpu<double>(int, double*, double*, double*,
+    double, double, double);
+
+}  // namespace caffe
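Reading the fused kernel back into math: with decay rho (momentum), conditioning term delta, learning rate eta (local_rate), and gradient g_t, the kernel performs, per element (a restatement of the kernel body above, in the E[.] notation of Zeiler's ADADELTA paper):

\begin{aligned}
E[g^2]_t &= \rho\, E[g^2]_{t-1} + (1-\rho)\, g_t^2 && \text{(h)} \\
\Delta x_t &= g_t \sqrt{\frac{E[\Delta x^2]_{t-1} + \delta}{E[g^2]_t + \delta}} && \text{(rescaled gi)} \\
E[\Delta x^2]_t &= \rho\, E[\Delta x^2]_{t-1} + (1-\rho)\, \Delta x_t^2 && \text{(h2)} \\
g_t &\leftarrow \eta\, \Delta x_t
\end{aligned}

One launch replaces the eleven caffe_gpu_* calls removed above, and the temp_ blob is no longer touched on this path.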
37 changes: 9 additions & 28 deletions src/caffe/solvers/adagrad_solver.cpp
@@ -4,6 +4,12 @@
 
 namespace caffe {
 
+#ifndef CPU_ONLY
+template <typename Dtype>
+void adagrad_update_gpu(int N, Dtype* g, Dtype* h, Dtype delta,
+    Dtype local_rate);
+#endif
+
 template <typename Dtype>
 void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   CHECK(Caffe::root_solver());
@@ -45,34 +51,9 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
   case Caffe::GPU: {
 #ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_add(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        this->history_[param_id]->gpu_data(),
-        this->history_[param_id]->mutable_gpu_data());
-
-    // prepare update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        this->history_[param_id]->gpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add_scalar(net_params[param_id]->count(),
-        delta, this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_div(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(),
-        this->update_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // scale and copy
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->gpu_data(), Dtype(0),
-        net_params[param_id]->mutable_gpu_diff());
+    adagrad_update_gpu(net_params[param_id]->count(),
+        net_params[param_id]->mutable_gpu_diff(),
+        this->history_[param_id]->mutable_gpu_data(), delta, local_rate);
 #else
     NO_GPU;
 #endif
26 changes: 26 additions & 0 deletions src/caffe/solvers/adagrad_solver.cu
@@ -0,0 +1,26 @@
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void AdaGradUpdate(int N, Dtype* g, Dtype* h, Dtype delta,
+    Dtype local_rate) {
+  CUDA_KERNEL_LOOP(i, N) {
+    float gi = g[i];
+    float hi = h[i] = h[i] + gi*gi;
+    g[i] = local_rate * gi / (sqrt(hi) + delta);
+  }
+}
+template <typename Dtype>
+void adagrad_update_gpu(int N, Dtype* g, Dtype* h, Dtype delta,
+    Dtype local_rate) {
+  AdaGradUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, g, h, delta, local_rate);
+  CUDA_POST_KERNEL_CHECK;
+}
+template void adagrad_update_gpu<float>(int, float*, float*, float, float);
+template void adagrad_update_gpu<double>(int, double*, double*, double, double);
+
+}  // namespace caffe
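Per element, with accumulated squared-gradient history h, conditioning term delta, and learning rate eta, the kernel computes (restating the body above):

\begin{aligned}
h_t &= h_{t-1} + g_t^2 \\
g_t &\leftarrow \frac{\eta\, g_t}{\sqrt{h_t} + \delta}
\end{aligned}

in a single pass instead of the six caffe_gpu_* calls removed above.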
37 changes: 9 additions & 28 deletions src/caffe/solvers/adam_solver.cpp
@@ -16,6 +16,12 @@ void AdamSolver<Dtype>::AdamPreSolve() {
   }
 }
 
+#ifndef CPU_ONLY
+template <typename Dtype>
+void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
+    Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate);
+#endif
+
 template <typename Dtype>
 void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
@@ -69,34 +75,9 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
   case Caffe::GPU: {
 #ifndef CPU_ONLY
-    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
-    caffe_gpu_axpby(N, Dtype(1)-beta1,
-        net_params[param_id]->gpu_diff(), beta1,
-        val_m->mutable_gpu_data());
-
-    // update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
-    caffe_gpu_mul(N,
-        net_params[param_id]->gpu_diff(),
-        net_params[param_id]->gpu_diff(),
-        val_t->mutable_gpu_data());
-    caffe_gpu_axpby(N, Dtype(1)-beta2,
-        val_t->gpu_data(), beta2,
-        val_v->mutable_gpu_data());
-
-    // set update
-    caffe_gpu_powx(N,
-        val_v->gpu_data(), Dtype(0.5),
-        val_t->mutable_gpu_data());
-    caffe_gpu_add_scalar(N, eps_hat,
-        val_t->mutable_gpu_data());
-    caffe_gpu_div(N,
-        val_m->gpu_data(),
-        val_t->gpu_data(),
-        val_t->mutable_gpu_data());
-
-    caffe_gpu_scale(N, local_rate*correction,
-        val_t->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
+    adam_update_gpu(N, net_params[param_id]->mutable_gpu_diff(),
+        val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), beta1, beta2,
+        eps_hat, local_rate*correction);
 #else
     NO_GPU;
 #endif
29 changes: 29 additions & 0 deletions src/caffe/solvers/adam_solver.cu
@@ -0,0 +1,29 @@
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void AdamUpdate(int N, Dtype* g, Dtype* m, Dtype* v,
+    Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
+  CUDA_KERNEL_LOOP(i, N) {
+    float gi = g[i];
+    float mi = m[i] = m[i]*beta1 + gi*(1-beta1);
+    float vi = v[i] = v[i]*beta2 + gi*gi*(1-beta2);
+    g[i] = corrected_local_rate * mi / (sqrt(vi) + eps_hat);
+  }
+}
+template <typename Dtype>
+void adam_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
+    Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
+  AdamUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, g, m, v, beta1, beta2, eps_hat, corrected_local_rate);
+  CUDA_POST_KERNEL_CHECK;
+}
+template void adam_update_gpu<float>(int, float*, float*, float*,
+    float, float, float, float);
+template void adam_update_gpu<double>(int, double*, double*, double*,
+    double, double, double, double);
+
+}  // namespace caffe
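Per element, with beta1, beta2, eps_hat as in Kingma & Ba, and the bias correction folded into the rate on the host side (the .cpp passes local_rate*correction as corrected_local_rate, where correction is computed before the mode switch), the kernel performs:

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
g_t &\leftarrow \eta_{\mathrm{corr}}\, \frac{m_t}{\sqrt{v_t} + \hat\epsilon}
\end{aligned}

Note the fused path also stops using val_t as a scratch blob.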
29 changes: 10 additions & 19 deletions src/caffe/solvers/nesterov_solver.cpp
@@ -4,6 +4,12 @@
 
 namespace caffe {
 
+#ifndef CPU_ONLY
+template <typename Dtype>
+void nesterov_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
+    Dtype local_rate);
+#endif
+
 template <typename Dtype>
 void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   CHECK(Caffe::root_solver());
@@ -36,25 +42,10 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   }
   case Caffe::GPU: {
 #ifndef CPU_ONLY
-    // save history momentum for stepping back
-    caffe_copy(net_params[param_id]->count(),
-        this->history_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-        net_params[param_id]->gpu_diff(), momentum,
-        this->history_[param_id]->mutable_gpu_data());
-
-    // compute update: step back then over step
-    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
-        this->history_[param_id]->gpu_data(), -momentum,
-        this->update_[param_id]->mutable_gpu_data());
-
-    // copy
-    caffe_copy(net_params[param_id]->count(),
-        this->update_[param_id]->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
+    nesterov_update_gpu(net_params[param_id]->count(),
+        net_params[param_id]->mutable_gpu_diff(),
+        this->history_[param_id]->mutable_gpu_data(),
+        momentum, local_rate);
 #else
     NO_GPU;
 #endif
27 changes: 27 additions & 0 deletions src/caffe/solvers/nesterov_solver.cu
@@ -0,0 +1,27 @@
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void NesterovUpdate(int N, Dtype* g, Dtype* h,
+    Dtype momentum, Dtype local_rate) {
+  CUDA_KERNEL_LOOP(i, N) {
+    float hi = h[i];
+    float hi_new = h[i] = momentum * hi + local_rate * g[i];
+    g[i] = (1+momentum) * hi_new - momentum * hi;
+  }
+}
+template <typename Dtype>
+void nesterov_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
+    Dtype local_rate) {
+  NesterovUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, g, h, momentum, local_rate);
+  CUDA_POST_KERNEL_CHECK;
+}
+template void nesterov_update_gpu<float>(int, float*, float*, float, float);
+template void nesterov_update_gpu<double>(int, double*, double*, double,
+    double);
+
+}  // namespace caffe
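Per element, with momentum mu and learning rate eta, this is the usual "step back then over-step" form of Nesterov momentum (restating the kernel body):

\begin{aligned}
v_t &= \mu\, v_{t-1} + \eta\, g_t \\
g_t &\leftarrow (1+\mu)\, v_t - \mu\, v_{t-1}
\end{aligned}

The old GPU path needed two caffe_copy calls and the update_ blob just to keep v_{t-1} around for the step back; the fused kernel simply holds it in a register (hi).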
35 changes: 10 additions & 25 deletions src/caffe/solvers/rmsprop_solver.cpp
@@ -4,6 +4,12 @@
 
 namespace caffe {
 
+#ifndef CPU_ONLY
+template <typename Dtype>
+void rmsprop_update_gpu(int N, Dtype* g, Dtype* h, Dtype rms_decay,
+    Dtype delta, Dtype local_rate);
+#endif
+
 template <typename Dtype>
 void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
@@ -45,31 +51,10 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
     break;
   case Caffe::GPU:
 #ifndef CPU_ONLY
-    // compute square of gradient in update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), Dtype(2),
-        this->update_[param_id]->mutable_gpu_data());
-
-    // update history
-    caffe_gpu_axpby(net_params[param_id]->count(),
-        Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
-        rms_decay, this->history_[param_id]->mutable_gpu_data());
-
-    // prepare update
-    caffe_gpu_powx(net_params[param_id]->count(),
-        this->history_[param_id]->gpu_data(), Dtype(0.5),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_add_scalar(net_params[param_id]->count(),
-        delta, this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_div(net_params[param_id]->count(),
-        net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
-        this->update_[param_id]->mutable_gpu_data());
-
-    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-        this->update_[param_id]->gpu_data(), Dtype(0),
-        net_params[param_id]->mutable_gpu_diff());
+    rmsprop_update_gpu(net_params[param_id]->count(),
+        net_params[param_id]->mutable_gpu_diff(),
+        this->history_[param_id]->mutable_gpu_data(),
+        rms_decay, delta, local_rate);
 #else
     NO_GPU;
 #endif
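Per element this is

\begin{aligned}
h_t &= \kappa\, h_{t-1} + (1-\kappa)\, g_t^2 \\
g_t &\leftarrow \frac{\eta\, g_t}{\sqrt{h_t} + \delta}
\end{aligned}

with kappa = rms_decay. The matching src/caffe/solvers/rmsprop_solver.cu is one of the three changed files this page did not render, so its contents are not shown above; the forward declaration and the pattern of the other four kernels suggest its shape, and the following is a hypothetical reconstruction under that assumption, not the rendered diff:

// Hypothetical reconstruction of src/caffe/solvers/rmsprop_solver.cu,
// mirroring the fused kernels above; the real file may differ in details.
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void RMSPropUpdate(int N, Dtype* g, Dtype* h,
    Dtype rms_decay, Dtype delta, Dtype local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
    float gi = g[i];
    // h <- rms_decay * h + (1 - rms_decay) * g^2, kept for the next step
    float hi = h[i] = rms_decay*h[i] + (1-rms_decay)*gi*gi;
    // scaled gradient written back in place of the raw diff
    g[i] = local_rate * gi / (sqrt(hi) + delta);
  }
}
template <typename Dtype>
void rmsprop_update_gpu(int N, Dtype* g, Dtype* h, Dtype rms_decay,
    Dtype delta, Dtype local_rate) {
  RMSPropUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
      N, g, h, rms_decay, delta, local_rate);
  CUDA_POST_KERNEL_CHECK;
}
template void rmsprop_update_gpu<float>(int, float*, float*, float, float,
    float);
template void rmsprop_update_gpu<double>(int, double*, double*, double,
    double, double);

}  // namespace caffe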