Skip to content

Commit

Permalink
Merge pull request BVLC#3 from Russell91/load_data
Browse files Browse the repository at this point in the history
Load data
  • Loading branch information
Russell Stewart authored and Russell Stewart committed Aug 18, 2015
2 parents 5ef0bfd + cd9dcde commit 987b765
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
5 changes: 5 additions & 0 deletions include/caffe/apollonet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class ApolloNet {
CopyTrainedLayersFrom(param);
}

void CopyLayerFrom(const LayerParameter& source_layer);

void SaveTrainedLayersTo(const string trained_filename) const;

void Update(Dtype lr, Dtype momentum, Dtype clip_gradients,
Dtype weight_decay);

Expand Down Expand Up @@ -99,6 +103,7 @@ class ApolloNet {
map<string, Dtype> param_lr_mults_;
map<string, vector<shared_ptr<Blob<Dtype> > > > bottom_blobs_;
map<string, vector<string> > bottom_blob_names_;
map<string, LayerParameter> param_cache_;
vector<string> active_layers_vec_;
set<string> active_layers_set_;
set<string> active_params_set_;
Expand Down
17 changes: 11 additions & 6 deletions python/apollocaffe/cpp/_apollocaffe.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,21 @@ cdef class ApolloNet:
return blobs

def save(self, filename):
assert filename.endswith('.h5'), "saving only supports h5 files"
with h5py.File(filename, 'w') as f:
for name, value in self.params.items():
f[name] = pynp.copy(value.data)
_, extension = os.path.splitext(filename)
if extension == '.h5':
with h5py.File(filename, 'w') as f:
for name, value in self.params.items():
f[name] = pynp.copy(value.data)
elif extension == '.caffemodel':
self.thisptr.SaveTrainedLayersTo(filename)
else:
assert False, "Error, filename is neither h5 nor caffemodel: %s, %s" % (filename, extension)

def load(self, filename):
if len(self.params) == 0:
raise ValueError('WARNING, loading into empty net.')
_, extension = os.path.splitext(filename)
if extension == '.h5':
if len(self.params) == 0:
raise ValueError('WARNING, loading into empty net.')
with h5py.File(filename, 'r') as f:
params = self.params
names = []
Expand Down
1 change: 1 addition & 0 deletions python/apollocaffe/cpp/definitions.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ cdef extern from "caffe/apollonet.hpp" namespace "caffe":
void set_phase_train()
Phase phase()
void CopyTrainedLayersFrom(string trained_filename) except +
void SaveTrainedLayersTo(string trained_filename) except +
vector[string]& active_layer_names()
set[string]& active_param_names()

Expand Down
49 changes: 39 additions & 10 deletions src/caffe/apollonet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/upgrade_proto.hpp"

Expand Down Expand Up @@ -118,6 +119,9 @@ Dtype ApolloNet<Dtype>::ForwardLayer(shared_ptr<Layer<Dtype> > layer) {
if (new_layer) {
layer->SetUp(bottom_vec, top_vec);
AddLayerParams(layer);
if (param_cache_.find(layer_name) != param_cache_.end()) {
CopyLayerFrom(param_cache_[layer_name]);
}
}

for (int param_id = 0; param_id < layer->param_names().size(); ++param_id) {
Expand Down Expand Up @@ -247,6 +251,9 @@ Dtype ApolloNet<Dtype>::ForwardLayer(const string& layer_param_string) {
if (new_layer) {
layer->SetUp(bottom_vec, top_vec);
AddLayerParams(layer);
if (param_cache_.find(layer_name) != param_cache_.end()) {
CopyLayerFrom(param_cache_[layer_name]);
}
}

for (int param_id = 0; param_id < layer->param_names().size(); ++param_id) {
Expand Down Expand Up @@ -404,21 +411,43 @@ void ApolloNet<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
const string& source_layer_name = source_layer.name();

if (layers_map_.find(source_layer_name) == layers_map_.end()) {
LOG(INFO) << "Ignoring source layer " << source_layer_name;
param_cache_[source_layer_name] = source_layer;
LOG(INFO) << "Caching source layer blobs " << source_layer_name;
continue;
}
CopyLayerFrom(source_layer);
}
}

LOG(INFO) << "Copying source layer " << source_layer_name;
vector<shared_ptr<Blob<Dtype> > >& target_blobs =
layers_map_[source_layer_name]->blobs();
template <typename Dtype>
void ApolloNet<Dtype>::CopyLayerFrom(const LayerParameter& source_layer) {
  // Copy the serialized parameter blobs of source_layer into the live layer
  // of the same name in layers_map_. Blob shapes must already match
  // (kReshape = false makes FromProto enforce this).
  const string& source_layer_name = source_layer.name();
  LOG(INFO) << "Copying source layer blobs " << source_layer_name;

  // Look the layer up explicitly: operator[] would silently default-insert a
  // null shared_ptr for an unknown name and crash on the dereference below.
  typename map<string, shared_ptr<Layer<Dtype> > >::const_iterator it =
      layers_map_.find(source_layer_name);
  ASSERT(it != layers_map_.end(),
      "Attempting to copy blobs into nonexistent layer " << source_layer_name);
  vector<shared_ptr<Blob<Dtype> > >& target_blobs = it->second->blobs();

  ASSERT(target_blobs.size() == source_layer.blobs_size(),
      "Incompatible number of blobs for layer " << source_layer_name);
  for (int j = 0; j < target_blobs.size(); ++j) {
    const bool kReshape = false;
    target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
  }
}

ASSERT(target_blobs.size() == source_layer.blobs_size(),
"Incompatible number of blobs for layer " << source_layer_name);
for (int j = 0; j < target_blobs.size(); ++j) {
const bool kReshape = false;
target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
}
template <typename Dtype>
void ApolloNet<Dtype>::SaveTrainedLayersTo(const string trained_filename)
    const {
  // Serialize every live layer's parameters into a NetParameter message and
  // write it to trained_filename as a binary protobuf.
  NetParameter net_param;
  DLOG(INFO) << "Serializing " << layers_map_.size() << " layers";
  typename map<string, shared_ptr<Layer<Dtype> > >::const_iterator layer_it;
  for (layer_it = layers_map_.begin(); layer_it != layers_map_.end();
       ++layer_it) {
    // Each layer appends its own LayerParameter (name + blobs) to the message.
    layer_it->second->ToProto(net_param.add_layer());
  }
  WriteProtoToBinaryFile(net_param, trained_filename);
}

INSTANTIATE_CLASS(ApolloNet);
Expand Down

0 comments on commit 987b765

Please sign in to comment.