diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 9f9d975dfba..fa20a1c2698 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -5,6 +5,7 @@ #include #include +#include "boost/algorithm/string.hpp" #include "caffe/caffe.hpp" using caffe::Blob; @@ -76,6 +77,19 @@ int device_query() { } RegisterBrewFunction(device_query); +// Load the weights from the specified caffemodel(s) into the train- and +// test-nets. +void CopyLayers(caffe::Solver* solver, const std::string& model_list) { + std::vector model_names; + boost::split(model_names, model_list, boost::is_any_of(",") ); + for (int i = 0; i < model_names.size(); ++i) { + LOG(INFO) << "Finetuning from " << model_names[i]; + solver->net()->CopyTrainedLayersFrom(model_names[i]); + for (int j = 0; j < solver->test_nets().size(); ++j) { + solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]); + } + } +} // Train / Finetune a model. int train() { @@ -112,8 +126,7 @@ int train() { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Solve(FLAGS_snapshot); } else if (FLAGS_weights.size()) { - LOG(INFO) << "Finetuning from " << FLAGS_weights; - solver->net()->CopyTrainedLayersFrom(FLAGS_weights); + CopyLayers(&*solver, FLAGS_weights); solver->Solve(); } else { solver->Solve();