-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_caffe_weights.py
96 lines (73 loc) · 3.17 KB
/
extract_caffe_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
__author__ = "C. Clayton Violand"
__copyright__ = "Copyright 2017"
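## Extracts layer weights and biases (and batch-norm mean/variance) from Caffe
## models and writes them as CSV files to 'weights/<model_name>/'.
## REQUIRED: a Caffe .prototxt and .caffemodel in 'models/<model_name>/'.
## Models whose 'weights/<model_name>/' directory already exists are skipped.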
##
import os
import re
import sys
import numpy as np
import caffe
import cPickle
from helper.extract_architecture import extract_architecture
# Put Caffe in CPU mode once, at import time, before main() constructs any caffe.Net.
caffe.set_mode_cpu()
def _find_model_files(model_dir):
    """Return ``(prototxt_path, caffemodel_path)`` for the files in *model_dir*.

    Either element is ``None`` when no file with that extension is present.
    If several files share an extension, the one listed last by
    ``os.listdir`` (arbitrary order) is used.
    """
    prototxt_file_path = None
    model_file_path = None
    for f in os.listdir(model_dir):
        if f.endswith('.prototxt'):
            prototxt_file_path = os.path.join(model_dir, f)
        if f.endswith('.caffemodel'):
            model_file_path = os.path.join(model_dir, f)
    return prototxt_file_path, model_file_path


def _flatten_weights(weight_blob):
    """Collapse every axis after the first so the blob is at most 2-D.

    ``np.savetxt`` can only write 1-D or 2-D arrays, so e.g. a 4-D blob of
    shape ``(a, b, c, d)`` becomes ``(a, b*c*d)``. 1-D and 2-D blobs are
    returned unchanged.
    """
    if weight_blob.ndim > 2:
        return weight_blob.reshape(weight_blob.shape[0],
                                   int(np.prod(weight_blob.shape[1:])))
    return weight_blob


def _write_layer_params(net, layer_name, layer, out_dir):
    """Write the parameter blobs of one layer as CSV files in *out_dir*.

    Args:
        net: the loaded ``caffe.Net``; ``net.params[layer_name]`` must exist.
        layer_name: name used both as the ``net.params`` key and as the
            output file prefix.
        layer: this layer's entry from ``extract_architecture``; its
            ``'type'`` and optional ``'bias_term'`` values are read.
        out_dir: destination directory, created if missing.
    """
    params = net.params[layer_name]
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    def save(suffix, blob):
        np.savetxt(os.path.join(out_dir, layer_name + suffix), blob, delimiter=',')

    if layer['type'] == "batchnorm":
        mean_blob = params[0].data[...]
        var_blob = params[1].data[...]
        # Caffe's BatchNorm keeps running sums scaled by a moving-average
        # factor stored in blob 2; divide it out (factor 0 -> 0, as Caffe
        # does) so the files hold the actual mean and variance.
        if len(params) > 2:
            factor = params[2].data[0]
            scale = 0.0 if factor == 0 else 1.0 / factor
            mean_blob = mean_blob * scale
            var_blob = var_blob * scale
        save("_mean.csv", mean_blob)
        save("_var.csv", var_blob)
        return

    weight_blob = _flatten_weights(params[0].data[...])
    save("_weights.csv", weight_blob)

    # Emit zero biases when the prototxt disables the bias term or the layer
    # has no second parameter blob, so every such layer gets a biases file.
    if ("bias_term" in layer and layer['bias_term'] == "false") or len(params) < 2:
        bias_blob = np.zeros(weight_blob.shape[0])
    else:
        bias_blob = params[1].data[...]
    save("_biases.csv", bias_blob)


def main():
    """Extract parameters of every Caffe model under ``models/`` into CSV files.

    Every directory ``models/<model_name>`` yielded by ``os.walk`` must hold
    a ``.prototxt`` and a ``.caffemodel``. Models whose output directory
    ``weights/<model_name>`` already exists are skipped. For each layer that
    owns parameter blobs, this writes into ``weights/<model_name>/``:

    * batchnorm layers: ``<layer>_mean.csv`` and ``<layer>_var.csv``;
    * other layers: ``<layer>_weights.csv`` (blobs with more than two axes
      flattened to ``(shape[0], -1)``) and ``<layer>_biases.csv`` (zeros when
      the layer has no bias).

    Exits the process via ``sys.exit`` when a model directory lacks a
    ``.prototxt`` or a ``.caffemodel`` file.
    """
    dirs = [entry[0] for entry in os.walk('models/')]
    for d in dirs:
        model_match = re.search("models/(.+)", d)
        if not model_match:
            continue
        model = model_match.group(1)
        out_dir = os.path.join('weights', model)
        # An existing output directory marks the model as already extracted.
        if os.path.exists(out_dir):
            continue

        # Look the files up afresh for every directory: a missing file must
        # not fall back to a path found in a previously visited directory.
        prototxt_file_path, model_file_path = _find_model_files(d)
        if prototxt_file_path is None:
            sys.exit("Error: No suitable Caffe .prototxt found...")
        if model_file_path is None:
            sys.exit("Error: No suitable .caffemodel file found...")

        architecture = extract_architecture(prototxt_file_path)
        net = caffe.Net(prototxt_file_path, model_file_path, caffe.TEST)

        for key in architecture:
            # Presumably 'shape' describes the input rather than a layer.
            if key == "shape":
                continue
            layer = architecture[key]
            # Skip the known parameter-free types and any other layer for
            # which pycaffe exposes no parameter blobs.
            if layer['type'] in ("relu", "pooling", "eltwise") or key not in net.params:
                continue
            _write_layer_params(net, key, layer, out_dir)
# Entry point: run the extraction only when executed as a script, not on import.
if "__main__" == __name__:
    main()