Exposing HDF5 saving and loading to python #4227
Would you consider updating this to add loadHDF5 as well? If the point of this is to save models with tensors > 2GB, then it would be nice to be able to load them in PyCaffe too. Here's a patch that I'd suggest modifying yours to:
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -101,6 +101,14 @@
   WriteProtoToBinaryFile(net_param, filename.c_str());
 }
+void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
+  net.ToHDF5(filename.c_str(), false);
+}
+
+void Net_LoadHDF5(Net<Dtype>* net, string filename) {
+  net->CopyTrainedLayersFromHDF5(filename.c_str());
+}
+
 void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
     bp::object labels_obj) {
   // check that this network has an input MemoryDataLayer
@@ -254,6 +262,8 @@
         bp::return_value_policy<bp::copy_const_reference>()))
     .def("_set_input_arrays", &Net_SetInputArrays,
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
+    .def("load_hdf5", &Net_LoadHDF5)
+    .def("save_hdf5", &Net_SaveHDF5)
     .def("save", &Net_Save);
   bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -79,3 +79,17 @@
             for i in range(len(self.net.params[name])):
                 self.assertEqual(abs(self.net.params[name][i].data
                     - net2.params[name][i].data).sum(), 0)
+
+    def test_save_hdf5(self):
+        f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+        f.close()
+        self.net.save_hdf5(f.name)
+        net_file = simple_net_file(self.num_output)
+        net2 = caffe.Net(net_file, caffe.TRAIN)
+        net2.load_hdf5(f.name)
+        os.remove(net_file)
+        os.remove(f.name)
+        for name in self.net.params:
+            for i in range(len(self.net.params[name])):
+                self.assertEqual(abs(self.net.params[name][i].data
+                    - net2.params[name][i].data).sum(), 0)
Thank you @ajtulloch, I just added the load_hdf5 function.
LGTM. @longjon, @shelhamer?
self.net.save_hdf5(f.name)
net_file = simple_net_file(self.num_output)
net2 = caffe.Net(net_file, caffe.TRAIN)
net2.load_hdf5(f.name)
You're using camel case in the _caffe.cpp file (saveHDF5, loadHDF5), but snake case here. I think for consistency you should use save_hdf5, load_hdf5 in the _caffe.cpp file. It's Python PEP-8 style to use snake_case for member functions.
Could you fix the casing issue? That will fix Travis.
@philkr I think you meant
@shelhamer yes I did, changed it now.
Thanks Philipp for exposing hdf5 net serialization to pycaffe.
Thanks @philkr!
[pycaffe] expose saving/loading nets as hdf5 to python
Exposes a new function 'Net.saveHDF5' to python. This allows caffe to store model weights in the hdf5 format (which can easily be used in other libraries).
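As a rough illustration of how the two bindings discussed in this thread might be used together, here is a minimal sketch of a save/load round trip. It assumes only that the net object exposes the `save_hdf5(path)` and `load_hdf5(path)` methods registered in the patch above; the helper name `roundtrip_hdf5` and the `make_net` factory argument are hypothetical and not part of pycaffe.

```python
import os
import tempfile


def roundtrip_hdf5(net, make_net):
    """Save `net`'s weights to a temporary HDF5 file and load them into a new net.

    `net` is any object with a save_hdf5(path) method; `make_net` is a
    hypothetical zero-argument factory returning an object with
    load_hdf5(path), e.g. lambda: caffe.Net(net_file, caffe.TRAIN).
    """
    # mkstemp returns an open descriptor; close it so the net can write the path.
    fd, path = tempfile.mkstemp(suffix=".h5")
    os.close(fd)
    try:
        net.save_hdf5(path)
        net2 = make_net()
        net2.load_hdf5(path)
    finally:
        # Remove the temporary file even if saving or loading raises.
        os.remove(path)
    return net2
```

Compared with the test in the patch, the only design difference is that the temporary file is removed in a `finally` block, so a failed save or load does not leave it behind.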