From fe7bf3e6a5f3825eece90c3d03b15edec562ce68 Mon Sep 17 00:00:00 2001 From: nicodjimenez Date: Mon, 17 Aug 2015 21:55:33 -0400 Subject: [PATCH] exposed new parameter saving / loading functionality to python --- python/apollocaffe/cpp/_apollocaffe.pyx | 13 +++++++++---- python/apollocaffe/cpp/definitions.pxd | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/apollocaffe/cpp/_apollocaffe.pyx b/python/apollocaffe/cpp/_apollocaffe.pyx index 0b7056d4af1..43123239c0a 100644 --- a/python/apollocaffe/cpp/_apollocaffe.pyx +++ b/python/apollocaffe/cpp/_apollocaffe.pyx @@ -386,10 +386,15 @@ cdef class ApolloNet: return blobs def save(self, filename): - assert filename.endswith('.h5'), "saving only supports h5 files" - with h5py.File(filename, 'w') as f: - for name, value in self.params.items(): - f[name] = pynp.copy(value.data) + _, extension = os.path.splitext(filename) + if extension == '.h5': + with h5py.File(filename, 'w') as f: + for name, value in self.params.items(): + f[name] = pynp.copy(value.data) + elif extension == '.caffemodel': + self.thisptr.CopyTrainedLayersFrom(filename) + else: + assert False, "Error, filename is neither h5 nor caffemodel: %s, %s" % (filename, extension) def load(self, filename): if len(self.params) == 0: diff --git a/python/apollocaffe/cpp/definitions.pxd b/python/apollocaffe/cpp/definitions.pxd index 076d2ad5486..b2955565c0e 100644 --- a/python/apollocaffe/cpp/definitions.pxd +++ b/python/apollocaffe/cpp/definitions.pxd @@ -75,6 +75,7 @@ cdef extern from "caffe/apollonet.hpp" namespace "caffe": void set_phase_train() Phase phase() void CopyTrainedLayersFrom(string trained_filename) except + + void SaveTrainedLayersTo(string trained_filename) except + vector[string]& active_layer_names() set[string]& active_param_names()