Commit
Merge pull request BVLC#2095 from mtamburrano/skip_propagate_down_param
Added param skip_propagate_down to LayerParameter
weiliu89 committed Apr 14, 2015
2 parents 02c0ab1 + 8c3625f commit e7b6d3e
Showing 4 changed files with 191 additions and 4 deletions.
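
The new parameter is set per bottom blob in a layer's prototxt definition. The sketch below is adapted from the test net added in this commit; the layer and blob names are illustrative. Listing the flag once per bottom keeps backpropagation for 'innerproduct' and skips it for 'label_argmax'; the bracketed form propagate_down: [true, false] used in the test is equivalent.

layer {
  name: 'loss'
  type: 'SigmoidCrossEntropyLoss'
  bottom: 'innerproduct'
  bottom: 'label_argmax'
  # keep backprop to 'innerproduct', skip it for 'label_argmax'
  propagate_down: true
  propagate_down: false
  top: 'cross_entropy_loss'
  loss_weight: 0.1
}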
3 changes: 3 additions & 0 deletions include/caffe/net.hpp
@@ -137,6 +137,9 @@ class Net {
inline const vector<Dtype>& blob_loss_weights() const {
return blob_loss_weights_;
}
inline const vector<bool>& layer_need_backward() const {
return layer_need_backward_;
}
/// @brief returns the parameters
inline const vector<shared_ptr<Blob<Dtype> > >& params() const {
return params_;
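
The accessor added above exposes, per layer, whether backward will run once Net::Init has propagated the new flags. Below is a minimal sketch of how a caller might inspect it; the helper function name is hypothetical and the net is assumed to be already initialized.

#include "caffe/net.hpp"

// Hypothetical helper: prints which layers will run Backward.
// Assumes 'net' has already been initialized from a NetParameter.
template <typename Dtype>
void PrintBackwardPlan(const caffe::Net<Dtype>& net) {
  const std::vector<bool>& need = net.layer_need_backward();
  for (int i = 0; i < net.layers().size(); ++i) {
    LOG(INFO) << net.layer_names()[i] << " needs backward: "
              << (need[i] ? "yes" : "no");
  }
}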
43 changes: 39 additions & 4 deletions src/caffe/net.cpp
@@ -79,11 +79,18 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
}
// Setup layer.
const LayerParameter& layer_param = param.layer(layer_id);
if (layer_param.propagate_down_size() > 0) {
CHECK_EQ(layer_param.propagate_down_size(),
layer_param.bottom_size())
<< "propagate_down param must be specified"
<< "either 0 or bottom_size times ";
}
layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
layers_[layer_id]->set_net(this);
layer_names_.push_back(layer_param.name());
LOG(INFO) << "Creating Layer " << layer_param.name();
bool need_backward = false;

// Figure out this layer's input and output
for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
++bottom_id) {
@@ -152,15 +159,33 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
// Go through the net backwards to determine which blobs contribute to the
// loss. We can skip backward computation for blobs that don't contribute
// to the loss.
// Also checks whether none of the bottom blobs need backward computation
// (possible because of the skip_propagate_down param), in which case backward
// computation can be skipped for the entire layer.
set<string> blobs_under_loss;
set<string> blobs_skip_backp;
for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
bool layer_contributes_loss = false;
bool layer_skip_propagate_down = true;
for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
if (layers_[layer_id]->loss(top_id) ||
(blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
layer_contributes_loss = true;
}
if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
layer_skip_propagate_down = false;
}
if (layer_contributes_loss && !layer_skip_propagate_down)
break;
}
// If this layer can skip backward computation, then none of its bottom
// blobs need backpropagation either.
if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
layer_need_backward_[layer_id] = false;
for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
++bottom_id) {
bottom_need_backward_[layer_id][bottom_id] = false;
}
}
if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
@@ -179,6 +204,11 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
} else {
bottom_need_backward_[layer_id][bottom_id] = false;
}
if (!bottom_need_backward_[layer_id][bottom_id]) {
const string& blob_name =
blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
blobs_skip_backp.insert(blob_name);
}
}
}
// Handle force_backward if needed.
@@ -368,9 +398,9 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,

// Helper for Net::Init: add a new bottom blob to the net.
template <typename Dtype>
int Net<Dtype>::AppendBottom(const NetParameter& param,
const int layer_id, const int bottom_id,
set<string>* available_blobs, map<string, int>* blob_name_to_idx) {
int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
const int bottom_id, set<string>* available_blobs,
map<string, int>* blob_name_to_idx) {
const LayerParameter& layer_param = param.layer(layer_id);
const string& blob_name = layer_param.bottom(bottom_id);
if (available_blobs->find(blob_name) == available_blobs->end()) {
@@ -382,7 +412,12 @@ int Net<Dtype>::AppendBottom(const NetParameter& param,
bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
bottom_id_vecs_[layer_id].push_back(blob_id);
available_blobs->erase(blob_name);
const bool need_backward = blob_need_backward_[blob_id];
bool propagate_down = true;
// Check if the backpropagation on bottom_id should be skipped
if (layer_param.propagate_down_size() > 0)
propagate_down = layer_param.propagate_down(bottom_id);
const bool need_backward = blob_need_backward_[blob_id] &&
propagate_down;
bottom_need_backward_[layer_id].push_back(need_backward);
return blob_id;
}
4 changes: 4 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -291,6 +291,10 @@ message LayerParameter {

// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;

// Specifies on which bottoms the backpropagation should be skipped.
// The size must be either 0 or equal to the number of bottoms.
repeated bool propagate_down = 11;

// Rules controlling whether and when a layer is included in the network,
// based on the current NetState. You may specify a non-zero number of rules
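
On the C++ side the field is read through the standard protobuf-generated accessors, which is what Net<Dtype>::Init and Net<Dtype>::AppendBottom do above. Below is a minimal sketch of the per-bottom lookup and its default; the function name is hypothetical, and an empty propagate_down list means every bottom keeps backpropagation.

#include <iostream>
#include "caffe/proto/caffe.pb.h"

// Hypothetical sketch mirroring the per-bottom lookup in Net::AppendBottom.
void ShowPropagateDown() {
  caffe::LayerParameter layer_param;
  layer_param.add_bottom("innerproduct");
  layer_param.add_bottom("label_argmax");
  layer_param.add_propagate_down(true);   // backprop to 'innerproduct'
  layer_param.add_propagate_down(false);  // skip backprop to 'label_argmax'
  for (int i = 0; i < layer_param.bottom_size(); ++i) {
    // An empty list (propagate_down_size() == 0) defaults to true.
    const bool propagate_down = layer_param.propagate_down_size() > 0 ?
        layer_param.propagate_down(i) : true;
    std::cout << layer_param.bottom(i) << ": propagate_down = "
              << propagate_down << std::endl;
  }
}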
145 changes: 145 additions & 0 deletions src/caffe/test/test_net.cpp
@@ -613,6 +613,103 @@ class NetTest : public MultiDeviceTest<TypeParam> {
InitNetFromProtoString(proto);
}

virtual void InitSkipPropNet(bool test_skip_true) {
string proto =
"name: 'SkipPropTestNetwork' "
"layer { "
" name: 'data' "
" type: 'DummyData' "
" dummy_data_param { "
" shape { "
" dim: 5 "
" dim: 2 "
" dim: 3 "
" dim: 4 "
" } "
" data_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" shape { "
" dim: 5 "
" } "
" data_filler { "
" type: 'constant' "
" value: 0 "
" } "
" } "
" top: 'data' "
" top: 'label' "
"} "
"layer { "
" name: 'silence' "
" bottom: 'label' "
" type: 'Silence' "
"} "
"layer { "
" name: 'innerproduct' "
" type: 'InnerProduct' "
" inner_product_param { "
" num_output: 1 "
" weight_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" bias_filler { "
" type: 'constant' "
" value: 0 "
" } "
" } "
" param { "
" lr_mult: 1 "
" decay_mult: 1 "
" } "
" param { "
" lr_mult: 2 "
" decay_mult: 0 "
" } "
" bottom: 'data' "
" top: 'innerproduct' "
"} "
"layer { "
" name: 'ip_fake_labels' "
" type: 'InnerProduct' "
" inner_product_param { "
" num_output: 1 "
" weight_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" bias_filler { "
" type: 'constant' "
" value: 0 "
" } "
" } "
" bottom: 'data' "
" top: 'fake_labels' "
"} "
"layer { "
" name: 'argmax' "
" bottom: 'fake_labels' "
" top: 'label_argmax' "
" type: 'ArgMax' "
"} "
"layer { "
" name: 'loss' "
" bottom: 'innerproduct' "
" bottom: 'label_argmax' ";
if (test_skip_true)
proto += " propagate_down: [true, false] ";
else
proto += " propagate_down: [true, true] ";
proto +=
" top: 'cross_entropy_loss' "
" type: 'SigmoidCrossEntropyLoss' "
" loss_weight: 0.1 "
"} ";
InitNetFromProtoString(proto);
}

int seed_;
shared_ptr<Net<Dtype> > net_;
};
@@ -2224,4 +2321,52 @@ TYPED_TEST(NetTest, TestReshape) {
}
}

TYPED_TEST(NetTest, TestSkipPropagateDown) {
// check bottom_need_backward if propagate_down is true
this->InitSkipPropNet(false);
vector<bool> vec_layer_need_backward = this->net_->layer_need_backward();
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
if (this->net_->layer_names()[layer_id] == "loss") {
// access the bottom_need_backward entry corresponding to the label blob
bool need_back = this->net_->bottom_need_backward()[layer_id][1];
// if propagate_down is true, the loss layer will try to
// backpropagate on labels
CHECK_EQ(need_back, true)
<< "bottom_need_backward should be True";
}
if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
CHECK_EQ(vec_layer_need_backward[layer_id], true)
<< "layer_need_backward for ip_fake_labels should be True";
if (this->net_->layer_names()[layer_id] == "argmax")
CHECK_EQ(vec_layer_need_backward[layer_id], true)
<< "layer_need_backward for argmax should be True";
if (this->net_->layer_names()[layer_id] == "innerproduct")
CHECK_EQ(vec_layer_need_backward[layer_id], true)
<< "layer_need_backward for innerproduct should be True";
}
// check bottom_need_backward if propagate_down is false
this->InitSkipPropNet(true);
vec_layer_need_backward.clear();
vec_layer_need_backward = this->net_->layer_need_backward();
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
if (this->net_->layer_names()[layer_id] == "loss") {
// access the bottom_need_backward entry corresponding to the label blob
bool need_back = this->net_->bottom_need_backward()[layer_id][1];
// if propagate_down is false, the loss layer will not try to
// backpropagate on labels
CHECK_EQ(need_back, false)
<< "bottom_need_backward should be False";
}
if (this->net_->layer_names()[layer_id] == "ip_fake_labels")
CHECK_EQ(vec_layer_need_backward[layer_id], false)
<< "layer_need_backward for ip_fake_labels should be False";
if (this->net_->layer_names()[layer_id] == "argmax")
CHECK_EQ(vec_layer_need_backward[layer_id], false)
<< "layer_need_backward for argmax should be False";
if (this->net_->layer_names()[layer_id] == "innerproduct")
CHECK_EQ(vec_layer_need_backward[layer_id], true)
<< "layer_need_backward for innerproduct should be True";
}
}

} // namespace caffe
