Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Tools] Caffe Converter #229

Merged
merged 6 commits into from
Oct 8, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tools/caffe_converter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Convert Caffe Model to Mxnet Format

## Introduction

This is an experimental tool for converting Caffe models into mxnet models. There are several limitations to note:
* Please first make sure that there is a corresponding operator in mxnet before conversion.
* The tool only supports single input and single output network.
* The tool can only work with the L2LayerParameter in Caffe. For older versions, please use the ```upgrade_net_proto_binary``` and ```upgrade_net_proto_text``` tools in the ```tools``` folder of Caffe to upgrade them.

We have verified the conversion results for the VGG_16 and BVLC_googlenet models from the Caffe model zoo.

## Notes on Codes
* The core function for converting symbols is in ```convert_symbol.py```. ```proto2script``` converts the prototxt into the corresponding Python script that generates the symbol. Therefore, if you need to modify the auto-generated symbols, you can print out its return value. You can also find the supported layers/operators there.
* The weights are converted in ```convert_model.py```.

## Usage
Run ```python convert_model.py caffe_prototxt caffe_model save_model_name``` to convert the models. Run with ```-h``` for more details of parameters.

60 changes: 60 additions & 0 deletions tools/caffe_converter/convert_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import mxnet as mx
import caffe
import argparse
from convert_symbol import proto2symbol

def main():
    """Convert a Caffe model (prototxt + binary weights) into an mxnet model.

    Command-line arguments:
        caffe_prototxt: The prototxt file in Caffe format.
        caffe_model: The binary model parameter file in Caffe format.
        save_model_name: Prefix passed to ``model.save`` for the output.

    The symbol is built with ``proto2symbol``; the weight and bias blobs of
    every Convolution / InnerProduct layer are then copied into mxnet
    NDArrays named ``<layer>_weight`` and ``<layer>_bias``.  Shapes are
    inferred for a fixed input of ``(1, 3, 224, 224)``.
    """
    parser = argparse.ArgumentParser(
        description='Caffe prototxt to mxnet model parameter converter. '
                    'Note that only basic functions are implemented. '
                    'You are welcomed to contribute to this file.')
    parser.add_argument('caffe_prototxt', help='The prototxt file in Caffe format')
    parser.add_argument('caffe_model', help='The binary model parameter file in Caffe format')
    parser.add_argument('save_model_name', help='The name of the output model prefix')
    args = parser.parse_args()

    prob = proto2symbol(args.caffe_prototxt)
    caffe.set_mode_cpu()
    net_caffe = caffe.Net(args.caffe_prototxt, args.caffe_model, caffe.TEST)
    arg_shapes, output_shapes, aux_shapes = prob.infer_shape(data=(1, 3, 224, 224))
    arg_names = prob.list_arguments()
    arg_shape_dic = dict(zip(arg_names, arg_shapes))
    arg_params = {}

    first_conv = True
    layer_names = net_caffe._layer_names
    for layer_idx, layer in enumerate(net_caffe.layers):
        # Only layers that carry learnable weight/bias blobs are converted.
        if layer.type != 'Convolution' and layer.type != 'InnerProduct':
            continue

        # '/' is replaced so the name matches the symbol names generated
        # by the prototxt converter.
        layer_name = layer_names[layer_idx].replace('/', '_')
        assert(len(layer.blobs) == 2)
        wmat = layer.blobs[0].data
        bias = layer.blobs[1].data

        # The channel swap is only meaningful for the first *convolution*,
        # whose weights index the input channels on axis 1.  An InnerProduct
        # layer seen before any convolution must not be swapped.
        if first_conv and layer.type == 'Convolution':
            print('Swapping BGR of caffe into RGB in mxnet')
            wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :]
            first_conv = False

        assert(wmat.flags['C_CONTIGUOUS'] is True)
        assert(bias.flags['C_CONTIGUOUS'] is True)
        print('converting layer {0}, wmat shape = {1}, bias shape = {2}'.format(
            layer_name, wmat.shape, bias.shape))
        weight_name = layer_name + "_weight"
        bias_name = layer_name + "_bias"

        # Reshape straight to the shape mxnet inferred for each argument.
        wmat = wmat.reshape(arg_shape_dic[weight_name])
        arg_params[weight_name] = mx.nd.zeros(wmat.shape)
        arg_params[weight_name][:] = wmat

        bias = bias.reshape(arg_shape_dic[bias_name])
        arg_params[bias_name] = mx.nd.zeros(bias.shape)
        arg_params[bias_name][:] = bias

    model = mx.model.FeedForward(ctx=mx.cpu(), symbol=prob,
                                 arg_params=arg_params, aux_params={}, num_round=1,
                                 learning_rate=0.05, momentum=0.9, wd=0.0001)

    model.save(args.save_model_name)

# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
123 changes: 123 additions & 0 deletions tools/caffe_converter/convert_symbol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import caffe
from caffe.proto import caffe_pb2
from google.protobuf import text_format
import argparse

def readProtoSolverFile(filepath):
    """Parse the text-format prototxt at ``filepath`` into a new NetParameter.

    Despite the name, the message created here is a ``NetParameter``;
    the actual reading and merging is delegated to ``readProtoFile``.

    Args:
        filepath: Path of the prototxt file to read.

    Returns:
        The ``NetParameter`` returned by ``readProtoFile``.
    """
    return readProtoFile(filepath, caffe.proto.caffe_pb2.NetParameter())

def readProtoFile(filepath, parser_object):
    """Merge the text-format protobuf stored at ``filepath`` into ``parser_object``.

    Args:
        filepath: Path of the text-format protobuf file to read.
        parser_object: Protobuf message instance that receives the parsed
            fields (it is modified in place).

    Returns:
        The same ``parser_object`` that was passed in.

    Raises:
        IOError: If the file cannot be opened (raised by ``open`` itself).
    """
    # ``open`` raises on failure instead of returning a falsy handle, so no
    # explicit check is needed.  The ``with`` block guarantees the handle is
    # closed even when ``text_format.Merge`` raises on malformed input.
    with open(filepath, "r") as proto_handle:
        text_format.Merge(str(proto_handle.read()), parser_object)
    return parser_object

def proto2script(proto_file):
    """Translate a Caffe prototxt into Python source that builds an mxnet symbol.

    Each supported Caffe layer becomes one line of the form
    ``<name> = mx.symbol.<Op>(name='<name>', data=<input> , <params>)``.
    A ``mx.symbol.Flatten`` line is inserted in front of a FullyConnected
    layer whenever its input comes from a layer flagged as needing a flatten.
    Split layers emit no code; their top blobs simply alias the symbol that
    produces their bottom blob.

    Args:
        proto_file: Path of the Caffe prototxt file.

    Returns:
        A tuple ``(symbol_string, output_name)``: the generated source (which
        expects ``mx`` and a variable called ``data`` to be in scope) and the
        name of the last Softmax layer seen (``""`` if there is none).

    Raises:
        Exception: On an unsupported layer type or an unknown pooling method.
    """
    proto = readProtoSolverFile(proto_file)
    layers = proto.layer
    flatten_count = 0
    script_lines = []

    # We assume the first bottom blob of first layer is the output from data layer
    input_name = layers[0].bottom[0]
    output_name = ""
    # Caffe blob name -> name of the generated Python variable producing it.
    mapping = {input_name: 'data'}
    # Generated variable name -> whether a Flatten must precede a
    # FullyConnected consumer.  Every lookup goes through ``mapping``, so the
    # input entry has to be keyed by the symbol name 'data', not by the
    # Caffe blob name (which may differ).
    need_flatten = {'data': False}

    for current in layers:
        type_string = ''
        param_string = ''
        # '/' is not valid in a Python identifier.
        name = current.name.replace('/', '_')
        layer_type = current.type

        if layer_type == 'Convolution':
            type_string = 'mx.symbol.Convolution'
            param = current.convolution_param
            pad = 0 if len(param.pad) == 0 else param.pad[0]
            stride = 1 if len(param.stride) == 0 else param.stride[0]
            param_string = "num_filter=%d, pad=(%d,%d), kernel=(%d,%d), stride=(%d,%d), no_bias=%s" % (
                param.num_output, pad, pad, param.kernel_size[0],
                param.kernel_size[0], stride, stride, not param.bias_term)
            need_flatten[name] = True
        elif layer_type == 'Pooling':
            type_string = 'mx.symbol.Pooling'
            param = current.pooling_param
            param_string = "pad=(%d,%d), kernel=(%d,%d), stride=(%d,%d)" % (
                param.pad, param.pad, param.kernel_size,
                param.kernel_size, param.stride, param.stride)
            if param.pool == 0:
                param_string = param_string + ", pool_type='max'"
            elif param.pool == 1:
                param_string = param_string + ", pool_type='avg'"
            else:
                raise Exception("Unknown Pooling Method!")
            need_flatten[name] = True
        elif layer_type == 'ReLU':
            type_string = 'mx.symbol.Activation'
            param_string = "act_type='relu'"
            # Element-wise layer: inherits the flatten requirement of its input.
            need_flatten[name] = need_flatten[mapping[current.bottom[0]]]
        elif layer_type == 'LRN':
            type_string = 'mx.symbol.LRN'
            param = current.lrn_param
            param_string = "alpha=%f, beta=%f, knorm=%f, nsize=%d" % (
                param.alpha, param.beta, param.k, param.local_size)
            need_flatten[name] = True
        elif layer_type == 'InnerProduct':
            type_string = 'mx.symbol.FullyConnected'
            param = current.inner_product_param
            param_string = "num_hidden=%d, no_bias=%s" % (param.num_output, not param.bias_term)
            need_flatten[name] = False
        elif layer_type == 'Dropout':
            type_string = 'mx.symbol.Dropout'
            param = current.dropout_param
            param_string = "p=%f" % param.dropout_ratio
            # Element-wise layer: inherits the flatten requirement of its input.
            need_flatten[name] = need_flatten[mapping[current.bottom[0]]]
        elif layer_type == 'Softmax':
            type_string = 'mx.symbol.Softmax'

            # We only support single output network for now.
            output_name = name
        elif layer_type == 'Flatten':
            type_string = 'mx.symbol.Flatten'
            need_flatten[name] = False
        elif layer_type == 'Split':
            type_string = 'split'
        elif layer_type == 'Concat':
            type_string = 'mx.symbol.Concat'
            need_flatten[name] = True

        if type_string == '':
            raise Exception('Unknown Layer %s!' % current.type)

        if type_string == 'split':
            # No symbol is generated for a Split, so its tops must refer to
            # the variable that already produces the bottom blob; pointing
            # them at the split's own name would reference an undefined
            # variable in the generated script.
            source_symbol = mapping[current.bottom[0]]
            for top_name in current.top:
                mapping[top_name] = source_symbol
            continue

        bottom = current.bottom
        if param_string != "":
            param_string = ", " + param_string
        if len(bottom) == 1:
            input_symbol = mapping[bottom[0]]
            if need_flatten[input_symbol] and type_string == 'mx.symbol.FullyConnected':
                flatten_name = "flatten_%d" % flatten_count
                script_lines.append("%s=mx.symbol.Flatten(name='%s', data=%s)\n" % (
                    flatten_name, flatten_name, input_symbol))
                flatten_count += 1
                need_flatten[flatten_name] = False
                # Feed the layer from the inserted Flatten without touching
                # the parsed proto.
                input_symbol = flatten_name
            script_lines.append("%s = %s(name='%s', data=%s %s)\n" % (
                name, type_string, name, input_symbol, param_string))
        else:
            script_lines.append("%s = %s(name='%s', *[%s] %s)\n" % (
                name, type_string, name,
                ','.join([mapping[x] for x in bottom]), param_string))

        # Every top blob of this layer is now produced by variable ``name``.
        for top_name in current.top:
            mapping[top_name] = name

    return ''.join(script_lines), output_name

def proto2symbol(proto_file):
    """Build the mxnet symbol described by a Caffe prototxt file.

    The source produced by ``proto2script`` is prefixed with the mxnet
    import and the definition of the ``data`` input variable, executed, and
    the variable named after the network's output layer is returned.

    Args:
        proto_file: Path of the Caffe prototxt file.

    Returns:
        The object bound to the output layer's name by the generated script.

    Raises:
        ValueError: If ``proto2script`` reports no output layer (no Softmax
            layer was found in the prototxt).
    """
    script, output_name = proto2script(proto_file)
    if not output_name:
        raise ValueError("No Softmax output layer found in %s" % proto_file)
    script = "import mxnet as mx\n" \
        + "data = mx.symbol.Variable(name='data')\n" \
        + script
    # NOTE(security): this executes code generated from the prototxt (layer
    # names are embedded verbatim), so only convert files you trust.
    # An explicit namespace is used instead of relying on ``exec`` injecting
    # names into the function's locals, which only works on Python 2.
    namespace = {}
    exec(script, namespace)
    return namespace[output_name]