Exposing HDF5 saving and loading to python #4227

Merged 1 commit on Jun 3, 2016

Conversation

@philkr
Contributor

philkr commented May 27, 2016

Exposes a new function 'Net.saveHDF5' to Python. This allows Caffe to store model weights in the HDF5 format, which other libraries can read easily.

@ajtulloch
Contributor

Would you consider updating this to add loadHDF5 as well? If the point of this is to save models with tensors > 2GB, then it would be nice to be able to load them in PyCaffe too.

Here's a patch I'd suggest modifying yours to match:

diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -101,6 +101,14 @@
   WriteProtoToBinaryFile(net_param, filename.c_str());
 }

+void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
+  net.ToHDF5(filename.c_str(), false);
+}
+
+void Net_LoadHDF5(Net<Dtype>* net, string filename) {
+  net->CopyTrainedLayersFromHDF5(filename.c_str());
+}
+
 void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
     bp::object labels_obj) {
   // check that this network has an input MemoryDataLayer
@@ -254,6 +262,8 @@
         bp::return_value_policy<bp::copy_const_reference>()))
     .def("_set_input_arrays", &Net_SetInputArrays,
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
+    .def("load_hdf5", &Net_LoadHDF5)
+    .def("save_hdf5", &Net_SaveHDF5)
     .def("save", &Net_Save);

   bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -79,3 +79,17 @@
             for i in range(len(self.net.params[name])):
                 self.assertEqual(abs(self.net.params[name][i].data
                     - net2.params[name][i].data).sum(), 0)
+
+    def test_save_hdf5(self):
+        f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+        f.close()
+        self.net.save_hdf5(f.name)
+        net_file = simple_net_file(self.num_output)
+        net2 = caffe.Net(net_file, caffe.TRAIN)
+        net2.load_hdf5(f.name)
+        os.remove(net_file)
+        os.remove(f.name)
+        for name in self.net.params:
+            for i in range(len(self.net.params[name])):
+                self.assertEqual(abs(self.net.params[name][i].data
+                    - net2.params[name][i].data).sum(), 0)

@philkr
Contributor Author

philkr commented May 29, 2016

Thank you @ajtulloch, I just added the load_hdf5 function.
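
A minimal sketch of how the two methods might be used together from Python, mirroring the test in the patch above. The helper name roundtrip_hdf5 and the make_net callback are hypothetical and not part of pycaffe; only save_hdf5 and load_hdf5 come from this PR.

```python
import os
import tempfile


def roundtrip_hdf5(net, make_net):
    """Save the weights of `net` to a temporary HDF5 file and load them
    into a freshly constructed net.

    Assumptions: `net` exposes save_hdf5(path) as added in this PR, and
    `make_net()` (hypothetical) returns a new net with the same
    architecture that exposes load_hdf5(path).
    """
    # Create an empty temp file and close our handle; the net reopens
    # the file by name when saving.
    fd, path = tempfile.mkstemp(suffix=".h5")
    os.close(fd)
    try:
        net.save_hdf5(path)
        net2 = make_net()
        net2.load_hdf5(path)
    finally:
        # Clean up the temp file even if saving or loading fails.
        os.remove(path)
    return net2
```

With pycaffe, make_net could be something like `lambda: caffe.Net(net_file, caffe.TRAIN)`, as in the test above; the returned net's params can then be compared against the original.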

@ajtulloch
Contributor

LGTM. @longjon, @shelhamer?

self.net.save_hdf5(f.name)
net_file = simple_net_file(self.num_output)
net2 = caffe.Net(net_file, caffe.TRAIN)
net2.load_hdf5(f.name)
Contributor

@ajtulloch ajtulloch May 29, 2016

You're using camel case in the _caffe.cpp file (saveHDF5, loadHDF5), but snake case here. For consistency, I think you should use save_hdf5 and load_hdf5 in the _caffe.cpp file. PEP 8 recommends snake_case for function and method names.

@ajtulloch
Contributor

Could you fix the casing issue? That will fix the Travis build.

@shelhamer
Member

shelhamer commented Jun 2, 2016

@philkr please squash your commits for merge. We should expose hdf5 load/save to resolve #4192 and for the purpose of interoperability as you mentioned.

This is straightforward, but @longjon could also double-check the diff.

@shelhamer
Member

@philkr I think you meant "Exposing save_hdf5 and load_hdf5 to python" for the commit message.

@philkr
Contributor Author

philkr commented Jun 2, 2016

@shelhamer yes I did, changed it now.

@shelhamer shelhamer merged commit df412ac into BVLC:master Jun 3, 2016
@shelhamer
Member

Thanks Philipp for exposing hdf5 net serialization to pycaffe.

@shelhamer shelhamer changed the title from "Exposing HDF5 saving to python" to "Exposing HDF5 saving and loading to python" Jun 3, 2016
@ajtulloch
Contributor

Thanks @philkr!

fxbit pushed a commit to Yodigram/caffe that referenced this pull request Sep 1, 2016
[pycaffe] expose saving/loading nets as hdf5 to python