diff --git a/src/common/layer-parser.js b/src/common/layer-parser.js new file mode 100644 index 000000000..3b96b6a8d --- /dev/null +++ b/src/common/layer-parser.js @@ -0,0 +1,68 @@ +(function(root, factory){ + if(typeof define === "function" && define.amd) { + define(['./lua'], function(luajs){ + return (root.LayerParser = factory(luajs)); + }); + } else if(typeof module === "object" && module.exports) { + var luajs = require('./lua'); + module.exports = (root.LayerParser = factory(luajs)); + } +}(this, function(luajs) { + var LayerParser = {}; + + var findInitParams = function(ast){ + // Find '__init' function + var params; + ast.block.stats.forEach(function(block){ + if(block.key && block.key.val == '__init' && block.func){ + params = block.func.args; + if(params.length === 0 && block.func.varargs){ + params[0] = 'params'; + } + } + }); + return params; + }; + + var findTorchClass = function(ast){ + var torchClassArgs, // args for `torch.class(...)` + name = '', + baseType, + params = []; + + if(ast.type == 'function'){ + ast.block.stats.forEach(function(func){ + if(func.type == 'stat.local' && func.right && func.right[0] && + func.right[0].func && func.right[0].func.self && + func.right[0].func.self.val == 'torch' && + func.right[0].func.key.val == 'class'){ + + torchClassArgs = func.right[0].args.map(arg => arg.val); + name = torchClassArgs[0]; + if(name !== ''){ + name = name.replace('nn.', ''); + params = findInitParams(ast); + if (torchClassArgs.length > 1) { + baseType = torchClassArgs[1].replace('nn.', ''); + } + } + } + }); + // If there is a name, check for methods owned by the given name + // which modify a 'self' value and return 'self' + // TODO + } + return { + name, + baseType, + params + }; + }; + + LayerParser.parse = function(text) { + var ast = luajs.parser.parse(text); + return findTorchClass(ast); + }; + + return LayerParser; +})); diff --git a/utils/nn-parser.js b/utils/nn-parser.js index a0297afc1..b81a9a560 100644 --- a/utils/nn-parser.js +++ b/utils/nn-parser.js @@ -1,60 +1,20 @@ -var fs = require('fs'); -var path = require('path'); -var parser = require('../src/common/lua').parser; -var torchPath = process.env.HOME + '/torch/extra/nn/'; -var SKIP_LAYERS = {}; -var skipLayerList = require('./skipLayers.json'); -skipLayerList.forEach(name => SKIP_LAYERS[name] = true); - -var findInitParams = function(ast){ - // Find '__init' function - var params; - ast.block.stats.forEach(function(block){ - if(block.key && block.key.val == '__init' && block.func){ - params = block.func.args; - if(params.length === 0 && block.func.varargs){ - params[0] = 'params'; - } - } - }); - return params; -}; +var fs = require('fs'), + path = require('path'), + torchPath, -var findTorchClass = function(ast){ - var torchClassArgs, // args for `torch.class(...)` - name = '', - baseType, - params = []; + LayerParser = require(__dirname + '/../src/common/layer-parser'), + SKIP_LAYERS = {}, + skipLayerList = require('./skipLayers.json'), - if(ast.type == 'function'){ - ast.block.stats.forEach(function(func){ - if(func.type == 'stat.local' && func.right && func.right[0] && - func.right[0].func && func.right[0].func.self && - func.right[0].func.self.val == 'torch' && - func.right[0].func.key.val == 'class'){ + categories = require('./categories.json'), + catNames = Object.keys(categories); + layerToCategory = {}; - torchClassArgs = func.right[0].args.map(arg => arg.val); - name = torchClassArgs[0]; - if(name !== ''){ - name = name.replace('nn.', ''); - params = findInitParams(ast); - if (torchClassArgs.length > 1) { - baseType = torchClassArgs[1].replace('nn.', ''); - } - } - } - }); - } - return { - name, - baseType, - params - }; -}; +// Check the deepforge config +// FIXME +torchPath = process.env.HOME + '/torch/extra/nn/'; -var categories = require('./categories.json'); -var catNames = Object.keys(categories); -var layerToCategory = {}; +skipLayerList.forEach(name => SKIP_LAYERS[name] = true); catNames.forEach(cat => // create layer -> category dictionary categories[cat].forEach(lname => layerToCategory[lname] = cat) ); @@ -73,14 +33,12 @@ fs.readdir(torchPath, function(err,files){ layers = files.filter(filename => path.extname(filename) === '.lua') .map(filename => fs.readFileSync(torchPath + filename, 'utf8')) - .map(code => parser.parse(code)) - .map(ast => findTorchClass(ast)) // create initial layers + .map(code => LayerParser.parse(code)) .filter(layer => !!layer && layer.name); layers.forEach(layer => { layer.type = lookupType(layer.name); layerByName[layer.name] = layer; - layer.setters = []; }); // handle inheritance