Skip to content

Commit

Permalink
Make backward pass work when global stats is active for BatchNormLayer
Browse files Browse the repository at this point in the history
including minor code cleaning
  • Loading branch information
kkhoot committed Nov 10, 2015
1 parent de015c5 commit 641e3b4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
12 changes: 5 additions & 7 deletions src/caffe/layers/batch_norm_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,11 @@ template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
CHECK(!use_global_stats_);
const Dtype* top_diff;
if (bottom[0] != top[0]) {
top_diff = top[0]->cpu_diff();
} else {
caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
top_diff = x_norm_.cpu_diff();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
if (use_global_stats_) {
caffe_div(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff);
return;
}
const Dtype* top_data = x_norm_.cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
Expand Down
12 changes: 5 additions & 7 deletions src/caffe/layers/batch_norm_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,11 @@ template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
CHECK(!use_global_stats_);
const Dtype* top_diff;
if (bottom[0] != top[0]) {
top_diff = top[0]->gpu_diff();
} else {
caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff());
top_diff = x_norm_.gpu_diff();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
if (use_global_stats_) {
caffe_gpu_div(temp_.count(), top_diff, temp_.gpu_data(), bottom_diff);
return;
}
const Dtype* top_data = x_norm_.gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
Expand Down

0 comments on commit 641e3b4

Please sign in to comment.