diff --git a/docs/installation.md b/docs/installation.md
index 73832716817..7d5419b31ff 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -31,7 +31,7 @@ The following sections detail prerequisites and installation on Ubuntu. For OS X
 * Boost
 * MKL (but see the [boost-eigen branch](https://github.com/BVLC/caffe/tree/boost-eigen) for a boost/Eigen3 port)
 * OpenCV
-* glog, gflags, protobuf, leveldb, snappy
+* glog, gflags, protobuf, leveldb, snappy, hdf5
 * For the Python wrapper: python, numpy (>= 1.7 preferred), and boost_python
 * For the Matlab wrapper: Matlab with mex
@@ -41,7 +41,7 @@ Caffe also needs Intel MKL as the backend of its matrix computation and vectoriz
 
 You will also need other packages, most of which can be installed via apt-get using:
 
-    sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libboost-all-dev
+    sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libboost-all-dev libhdf5-dev
 
 The only exception being the google logging library, which does not exist in the Ubuntu 12.04 repository. To install it, do:
 
@@ -81,8 +81,9 @@ Install [homebrew](http://brew.sh/) to install most of the prerequisites. Starti
     # install python by (1) Anaconda or (2) brew install python
 
     brew install --build-from-source boost
-    brew install snappy leveldb protobuf gflags glog
+    brew install snappy leveldb protobuf gflags glog hdf5
     brew tap homebrew/science
+    brew install homebrew/science/hdf5
     brew install homebrew/science/opencv
 
 Building boost from source is needed to link against your local python.
diff --git a/tools/extra/convert_net.py b/tools/extra/convert_net.py
new file mode 100644
index 00000000000..05938980e01
--- /dev/null
+++ b/tools/extra/convert_net.py
@@ -0,0 +1,301 @@
+import sys
+sys.path.append('../../python/')
+import os
+from caffe.proto import caffe_pb2
+import caffe.convert
+from google.protobuf import text_format
+import cPickle as pickle
+import numpy as np
+
+
+class CudaConvNetReader(object):
+    def __init__(self, net, readblobs=False, ignore_data_and_loss=True):
+        self.name = os.path.basename(net)
+        self.readblobs = readblobs
+        self.ignore_data_and_loss = ignore_data_and_loss
+
+        try:
+            net = pickle.load(open(net))
+        except ImportError:
+            # It wants the 'options' module from cuda-convnet,
+            # so we fake it by creating an object whose every member
+            # is a class that does nothing
+            faker = type('fake', (), {'__getattr__':
+                                      lambda s, n: type(n, (), {})})()
+            sys.modules['options'] = faker
+            net = pickle.load(open(net))
+
+        # Support either the full pickled net state
+        # or just the layer list
+        if isinstance(net, dict) and 'model_state' in net:
+            self.net = net['model_state']['layers']
+        elif isinstance(net, list):
+            self.net = net
+        else:
+            raise Exception("Unknown cuda-convnet net type")
+
+    neurontypemap = {'relu': 'relu',
+                     'logistic': 'sigmoid',
+                     'dropout': 'dropout'}
+
+    poolmethod = {
+        'max': caffe_pb2.LayerParameter.MAX,
+        'avg': caffe_pb2.LayerParameter.AVE
+    }
+
+    def read(self):
+        """
+        Read the cuda-convnet file and convert it to a dict that has the
+        same structure as a caffe protobuf
+        """
+        layers = []
+        datalayer = None
+
+        def find_non_neuron_ancestors(layer):
+            """Find the upstream layers that are not neurons"""
+            out = []
+            for l in layer.get('inputLayers', []):
+                if l['type'] == 'neuron':
+                    out += find_non_neuron_ancestors(l)
+                else:
+                    out += [l['name']]
+            return out
+
+        for layer in self.net:
+            layertype = layer['type'].split('.')[0]
+
+            if layer['name'] == 'data':
+                datalayer = layer
+
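+            # Data and cost layers carry no learned parameters; when
+            # ignore_data_and_loss is set they are dropped here and the
+            # input dimensions are added to the net dict after the loop.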
+            if self.ignore_data_and_loss and layertype in ['data', 'cost']:
+                continue
+
+            readfn = getattr(self, 'read_' + layertype)
+
+            convertedlayer = readfn(layer)
+
+            layerconnection = {}
+            layerconnection['layer'] = convertedlayer
+
+            # Add the top (our output) and bottom (input) links. Neuron layers
+            # operate "in place" so have the same top and bottom.
+            layerconnection['bottom'] = find_non_neuron_ancestors(layer)
+
+            if layer['type'] == "neuron":
+                layerconnection['top'] = layerconnection['bottom']
+            else:
+                layerconnection['top'] = [layer['name']]
+
+            layers.append(layerconnection)
+
+        netdict = {'name': self.name,
+                   'layers': layers}
+
+        # Add the hardcoded data dimensions instead of a data layer;
+        # this assumes the data layer is called "data" (since otherwise
+        # it is not trivial to distinguish it from a label layer)
+        if self.ignore_data_and_loss and datalayer is not None:
+            netdict['input'] = ["data"]
+            size = int(np.sqrt(datalayer['outputs'] / 3))
+            netdict['input_dim'] = [1, 3, size, size]
+
+        return netdict
+
+    def read_data(self, layer):
+        return {'type': 'data',
+                'name': layer['name']
+                }
+
+    def read_conv(self, layer):
+        assert len(layer['groups']) == 1
+        assert layer['filters'] % layer['groups'][0] == 0
+        assert layer['sharedBiases']
+
+        newlayer = {'type': 'conv',
+                    'name': layer['name'],
+                    'num_output': layer['filters'],
+                    'weight_filler': {'type': 'gaussian',
+                                      'std': layer['initW'][0]},
+                    'bias_filler': {'type': 'constant',
+                                    'value': layer['initB']},
+                    # cuda-convnet keeps padding as a negative start
+                    # offset, caffe as a positive pad, hence the sign flip
+                    'pad': -layer['padding'][0],
+                    'kernelsize': layer['filterSize'][0],
+                    'group': layer['groups'][0],
+                    'stride': layer['stride'][0],
+                    }
+
+        if self.readblobs:
+            # shape is ((channels/group)*filterSize*filterSize, nfilters)
+            # want (nfilters, channels/group, height, width)
+
+            weights = layer['weights'][0].T
+            weights = weights.reshape(layer['filters'],
+                                      layer['channels'][0] / layer['groups'][0],
+                                      layer['filterSize'][0],
+                                      layer['filterSize'][0])
+
+            biases = layer['biases'].flatten()
+            biases = biases.reshape(1, 1, 1, len(biases))
+
+            weightsblob = caffe.convert.array_to_blobproto(weights)
+            biasesblob = caffe.convert.array_to_blobproto(biases)
+            newlayer['blobs'] = [weightsblob, biasesblob]
+
+        return newlayer
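+
+    # Worked example of the conv weight reshape above (hypothetical
+    # sizes): an 11x11 convolution over 3 input channels with a single
+    # group and 96 filters arrives from cuda-convnet as a
+    # (3*11*11, 96) = (363, 96) matrix; the transpose and reshape turn
+    # it into the (96, 3, 11, 11) blob that caffe expects.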
+
+    def read_pool(self, layer):
+        return {'type': 'pool',
+                'name': layer['name'],
+                'pool': self.poolmethod[layer['pool']],
+                'kernelsize': layer['sizeX'],
+                'stride': layer['stride'],
+                }
+
+    def read_fc(self, layer):
+        newlayer = {'type': 'innerproduct',
+                    'name': layer['name'],
+                    'num_output': layer['outputs'],
+                    'weight_filler': {'type': 'gaussian',
+                                      'std': layer['initW'][0]},
+                    'bias_filler': {'type': 'constant',
+                                    'value': layer['initB']},
+                    }
+
+        if self.readblobs:
+            # shape is (ninputs, noutputs)
+            # want (1, 1, noutputs, ninputs)
+            weights = layer['weights'][0].T
+            weights = weights.reshape(1, 1, layer['outputs'],
+                                      layer['numInputs'][0])
+
+            biases = layer['biases'].flatten()
+            biases = biases.reshape(1, 1, 1, len(biases))
+
+            weightsblob = caffe.convert.array_to_blobproto(weights)
+            biasesblob = caffe.convert.array_to_blobproto(biases)
+
+            newlayer['blobs'] = [weightsblob, biasesblob]
+
+        return newlayer
+
+    def read_softmax(self, layer):
+        return {'type': 'softmax',
+                'name': layer['name']}
+
+    def read_cost(self, layer):
+        # TODO recognise when combined with softmax and
+        # use softmax_loss instead
+        if layer['type'] == "cost.logreg":
+            return {'type': 'multinomial_logistic_loss',
+                    'name': layer['name']}
+        raise NotImplementedError('unsupported cost type: ' + layer['type'])
+
+    def read_neuron(self, layer):
+        assert layer['neuron']['type'] in self.neurontypemap.keys()
+        return {'name': layer['name'],
+                'type': self.neurontypemap[layer['neuron']['type']]}
+
+    def read_cmrnorm(self, layer):
+        return {'name': layer['name'],
+                'type': "lrn",
+                'local_size': layer['size'],
+                # cuda-convnet sneakily divides by size when reading the
+                # net parameter file (layer.py:1041) so correct here
+                'alpha': layer['scale'] * layer['size'],
+                'beta': layer['pow']
+                }
+
+    def read_rnorm(self, layer):
+        # return self.read_cmrnorm(layer)
+        raise NotImplementedError('rnorm not implemented')
+
+    def read_cnorm(self, layer):
+        raise NotImplementedError('cnorm not implemented')
+
+
+class CudaConvNetWriter(object):
+    def __init__(self, net):
+        pass
+
+    def write_data(self, layer):
+        pass
+
+    def write_conv(self, layer):
+        pass
+
+    def write_pool(self, layer):
+        pass
+
+    def write_innerproduct(self, layer):
+        pass
+
+    def write_softmax_loss(self, layer):
+        pass
+
+    def write_softmax(self, layer):
+        pass
+
+    def write_multinomial_logistic_loss(self, layer):
+        pass
+
+    def write_relu(self, layer):
+        pass
+
+    def write_sigmoid(self, layer):
+        pass
+
+    def write_dropout(self, layer):
+        pass
+
+    def write_lrn(self, layer):
+        pass
+
+
+def cudaconv_to_prototxt(cudanet):
+    """Convert the cuda-convnet layer definition to caffe prototxt.
+    Takes the filename of a pickled cuda-convnet snapshot and returns
+    a string.
+    """
+    netdict = CudaConvNetReader(cudanet, readblobs=False).read()
+    protobufnet = dict_to_protobuf(netdict)
+    return text_format.MessageToString(protobufnet)
+
+
+def cudaconv_to_proto(cudanet):
+    """Convert a cuda-convnet pickled network (including weights)
+    to a caffe protobuf. Takes a filename of a pickled cuda-convnet
+    net and returns a NetParameter protobuf python object,
+    which can then be serialized with the SerializeToString() method
+    and written to a file.
+    """
+    netdict = CudaConvNetReader(cudanet, readblobs=True).read()
+    protobufnet = dict_to_protobuf(netdict)
+    return protobufnet
+
+
+# adapted from https://github.com/davyzhang/dict-to-protobuf/
+def list_to_protobuf(values, message):
+    """Parse a list into a protobuf message"""
+    if values == []:
+        pass
+    elif isinstance(values[0], dict):
+        # value needs to be further parsed
+        for val in values:
+            cmd = message.add()
+            dict_to_protobuf(val, cmd)
+    else:
+        # value can be set directly
+        message.extend(values)
+
+
+def dict_to_protobuf(values, message=None):
+    """Convert a dict to a protobuf message"""
+    if message is None:
+        message = caffe_pb2.NetParameter()
+
+    for k, val in values.iteritems():
+        if isinstance(val, dict):
+            # value needs to be further parsed
+            dict_to_protobuf(val, getattr(message, k))
+        elif isinstance(val, list):
+            list_to_protobuf(val, getattr(message, k))
+        else:
+            # value can be set directly
+            setattr(message, k, val)
+
+    return message
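+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (an illustration, not part of the converter
+    # API): convert a pickled cuda-convnet snapshot named on the command
+    # line and write the binary NetParameter next to it, as described in
+    # the cudaconv_to_proto docstring. The '.caffemodel' suffix is an
+    # arbitrary choice.
+    if len(sys.argv) != 2:
+        print 'Usage: python convert_net.py <cuda-convnet snapshot>'
+        sys.exit(1)
+    outname = sys.argv[1] + '.caffemodel'
+    with open(outname, 'wb') as f:
+        f.write(cudaconv_to_proto(sys.argv[1]).SerializeToString())
+    print 'Wrote', outname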