diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index ce82e28602c..44f28b2a04f 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -35,6 +35,16 @@ class Net {
   // Run forward with the input blobs already fed separately. You can get the
   // input blobs using input_blobs().
   const vector<Blob<Dtype>*>& ForwardPrefilled(Dtype* loss = NULL);
+
+  // The From and To variants of Forward and Backward operate on the
+  // (topological) ordering by which the net is specified. For general DAG
+  // networks, note that (1) computing from one layer to another might entail
+  // extra computation on unrelated branches, and (2) computation starting in
+  // the middle may be incorrect if all of the layers of a fan-in are not
+  // included.
+  Dtype ForwardFromTo(int start, int end);
+  Dtype ForwardFrom(int start);
+  Dtype ForwardTo(int end);
   // Run forward using a set of bottom blobs, and return the result.
   const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom,
       Dtype* loss = NULL);
@@ -46,6 +56,9 @@ class Net {
   // computes the gradient w.r.t the parameters, and the data has already
   // been provided during the forward pass.
   void Backward();
+  void BackwardFromTo(int start, int end);
+  void BackwardFrom(int start);
+  void BackwardTo(int end);
 
   Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
     Dtype loss;
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index e9fe5cd3b05..10fc23b7974 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -181,12 +181,12 @@ struct CaffeNet {
     }
   }
 
-  void Forward() {
-    net_->ForwardPrefilled();
+  void Forward(int start, int end) {
+    net_->ForwardFromTo(start, end);
   }
 
-  void Backward() {
-    net_->Backward();
+  void Backward(int start, int end) {
+    net_->BackwardFromTo(start, end);
   }
 
   void set_input_arrays(object data_obj, object labels_obj) {
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 5c1512cd8b9..0ac1868663d 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -34,8 +34,7 @@ def _Net_params(self):
     return OrderedDict([(lr.name, lr.blobs) for lr in self.layers
                         if len(lr.blobs) > 0])
 
-
-def _Net_forward(self, blobs=None, **kwargs):
+def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
     """
     Forward pass: prepare inputs and run the net forward.
 
@@ -44,6 +43,8 @@ def _Net_forward(self, blobs=None, **kwargs):
     kwargs: Keys are input blob names and values are blob ndarrays.
             For formatting inputs for Caffe, see Net.preprocess().
             If None, input is taken from data layers.
+    start: optional name of layer at which to begin the forward pass
+    end: optional name of layer at which to finish the forward pass (inclusive)
 
     Give
     outs: {blob name: blob ndarray} dict.
     """
     if blobs is None:
         blobs = []
 
+    if start is not None:
+        start_ind = [lr.name for lr in self.layers].index(start)
+    else:
+        start_ind = 0
+
+    if end is not None:
+        end_ind = [lr.name for lr in self.layers].index(end)
+        outputs = set([end] + blobs)
+    else:
+        end_ind = len(self.layers) - 1
+        outputs = set(self.outputs + blobs)
+
     if kwargs:
         if set(kwargs.keys()) != set(self.inputs):
             raise Exception('Input blob arguments do not match net inputs.')
@@ -63,14 +76,13 @@ def _Net_forward(self, blobs=None, **kwargs):
                 raise Exception('{} blob is not 4-d'.format(in_))
             self.blobs[in_].data[...] = blob
 
-    self._forward()
+    self._forward(start_ind, end_ind)
 
     # Unpack blobs to extract
-    outs = {out: self.blobs[out].data for out in set(self.outputs + blobs)}
-    return outs
+    return {out: self.blobs[out].data for out in outputs}
 
 
-def _Net_backward(self, diffs=None, **kwargs):
+def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
     """
     Backward pass: prepare diffs and run the net backward.
 
@@ -78,6 +90,8 @@ def _Net_backward(self, diffs=None, **kwargs):
     diffs: list of diffs to return in addition to bottom diffs.
     kwargs: Keys are output blob names and values are diff ndarrays.
             If None, top diffs are taken from forward loss.
+    start: optional name of layer at which to begin the backward pass
+    end: optional name of layer at which to finish the backward pass (inclusive)
 
     Give
     outs: {blob name: diff ndarray} dict.
     """
     if diffs is None:
         diffs = []
 
+    if start is not None:
+        start_ind = [lr.name for lr in self.layers].index(start)
+    else:
+        start_ind = len(self.layers) - 1
+
+    if end is not None:
+        end_ind = [lr.name for lr in self.layers].index(end)
+        outputs = set([end] + diffs)
+    else:
+        end_ind = 0
+        outputs = set(self.inputs + diffs)
+
     if kwargs:
         if set(kwargs.keys()) != set(self.outputs):
             raise Exception('Top diff arguments do not match net outputs.')
@@ -97,11 +123,10 @@ def _Net_backward(self, diffs=None, **kwargs):
                 raise Exception('{} diff is not 4-d'.format(top))
             self.blobs[top].diff[...] = diff
 
-    self._backward()
+    self._backward(start_ind, end_ind)
 
     # Unpack diffs to extract
-    outs = {out: self.blobs[out].diff for out in set(self.inputs + diffs)}
-    return outs
+    return {out: self.blobs[out].diff for out in outputs}
 
 
 def _Net_forward_all(self, blobs=None, **kwargs):
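The new start/end keywords give pycaffe callers partial passes by layer name. A rough usage sketch follows (the layer names 'conv1', 'fc7', and 'prob' are hypothetical, and selecting a layer by name assumes its top blob carries the same name, as is common in Caffe model definitions):

    # Full pass over every layer, as before.
    out = net.forward()

    # Stop after layer 'fc7'; the returned dict then holds the 'fc7' blob.
    feats = net.forward(end='fc7')

    # Resume from layer 'prob', reusing the blobs already in the net.
    probs = net.forward(start='prob')

    # Backward pass down to layer 'conv1' only; top diffs are taken from
    # the forward loss, as before.
    grads = net.backward(end='conv1')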
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index aba4cc2f4f2..cadcdcdf109 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -335,16 +335,34 @@ void Net<Dtype>::GetLearningRateAndWeightDecay() {
 }
 
 template <typename Dtype>
-const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
-  if (loss != NULL) {
-    *loss = Dtype(0.);
-  }
-  for (int i = 0; i < layers_.size(); ++i) {
+Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
+  CHECK_GE(start, 0);
+  CHECK_LT(end, layers_.size());
+  Dtype loss = 0;
+  for (int i = start; i <= end; ++i) {
     // LOG(ERROR) << "Forwarding " << layer_names_[i];
     Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
-    if (loss != NULL) {
-      *loss += layer_loss;
-    }
+    loss += layer_loss;
+  }
+  return loss;
+}
+
+template <typename Dtype>
+Dtype Net<Dtype>::ForwardFrom(int start) {
+  return ForwardFromTo(start, layers_.size() - 1);
+}
+
+template <typename Dtype>
+Dtype Net<Dtype>::ForwardTo(int end) {
+  return ForwardFromTo(0, end);
+}
+
+template <typename Dtype>
+const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
+  if (loss != NULL) {
+    *loss = ForwardFromTo(0, layers_.size() - 1);
+  } else {
+    ForwardFromTo(0, layers_.size() - 1);
   }
   return net_output_blobs_;
 }
@@ -380,10 +398,11 @@ string Net<Dtype>::Forward(const string& input_blob_protos, Dtype* loss) {
   return output;
 }
 
-
 template <typename Dtype>
-void Net<Dtype>::Backward() {
-  for (int i = layers_.size() - 1; i >= 0; --i) {
+void Net<Dtype>::BackwardFromTo(int start, int end) {
+  CHECK_GE(end, 0);
+  CHECK_LT(start, layers_.size());
+  for (int i = start; i >= end; --i) {
     if (layer_need_backward_[i]) {
       layers_[i]->Backward(
           top_vecs_[i], bottom_need_backward_[i], &bottom_vecs_[i]);
@@ -422,6 +441,21 @@ void Net<Dtype>::ShareTrainedLayersWith(Net* other) {
   }
 }
 
+template <typename Dtype>
+void Net<Dtype>::BackwardFrom(int start) {
+  BackwardFromTo(start, 0);
+}
+
+template <typename Dtype>
+void Net<Dtype>::BackwardTo(int end) {
+  BackwardFromTo(layers_.size() - 1, end);
+}
+
+template <typename Dtype>
+void Net<Dtype>::Backward() {
+  BackwardFromTo(layers_.size() - 1, 0);
+}
+
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
   int num_source_layers = param.layers_size();
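For orientation only (this restates what the Python wrappers above already do, with a hypothetical net object): the raw bound methods now take inclusive layer indices, matching the C++ defaults, and the name-based wrappers reduce to

    # Equivalent of net.forward(): run every layer by index.
    net._forward(0, len(net.layers) - 1)

    # Equivalent of net.backward(): run every layer in reverse.
    net._backward(len(net.layers) - 1, 0)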
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index c30f4168edc..e84701d941c 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -801,4 +801,38 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
   }
 }
 
+TYPED_TEST(NetTest, TestFromTo) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->InitTinyNet();
+
+  // Run Forward and Backward, recording the data diff and loss.
+  Blob<Dtype> data;
+  data.ReshapeLike(*this->net_->blob_by_name("data"));
+  this->net_->ForwardPrefilled();
+  this->net_->Backward();
+  data.CopyFrom(*this->net_->blob_by_name("data"), true, true);
+  const Dtype *loss_ptr = this->net_->output_blobs()[0]->cpu_data();
+  Dtype loss = *loss_ptr;
+
+  // Check that combining partial Forwards gives the same loss.
+  for (int i = 1; i < this->net_->layers().size(); ++i) {
+    // Note that we skip layer zero to keep the same data.
+    this->net_->ForwardFromTo(1, 1);
+    if (i < this->net_->layers().size() - 1) {
+      this->net_->ForwardFrom(i + 1);
+    }
+    EXPECT_EQ(loss, *loss_ptr);
+  }
+
+  // Check that combining partial Backwards gives the same data diff.
+  for (int i = 1; i < this->net_->layers().size(); ++i) {
+    this->net_->BackwardTo(i);
+    this->net_->BackwardFrom(i - 1);
+    for (int j = 0; j < data.count(); ++j) {
+      EXPECT_EQ(data.cpu_diff()[j],
+          this->net_->blob_by_name("data")->cpu_diff()[j]);
+    }
+  }
+}
+
 }  // namespace caffe
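The equivalence exercised by TestFromTo can also be sketched from Python. This is an illustration only, not part of the change: it assumes the net reads its inputs from data layers, is deterministic, and has at least four layers whose top blobs share the layer names; the split point is arbitrary.

    layer_names = [lr.name for lr in net.layers]

    # Record the outputs of a full pass; copy, since the returned arrays
    # are views into the net's blobs.
    full = {k: v.copy() for k, v in net.forward().items()}

    # Redo the pass in two pieces, skipping layer 0 to keep the same data.
    net.forward(start=layer_names[1], end=layer_names[2])
    combined = net.forward(start=layer_names[3])

    assert all((full[k] == combined[k]).all() for k in full)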