Skip to content

Commit

Permalink
WIP #541 Refactored nn-parser for better reusability
Browse files Browse the repository at this point in the history
  • Loading branch information
brollb committed Jul 26, 2016
1 parent 4c4b2b0 commit 626bcd3
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 56 deletions.
68 changes: 68 additions & 0 deletions src/common/layer-parser.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
(function(root, factory){
if(typeof define === "function" && define.amd) {
define(['./lua'], function(luajs){
return (root.LayerParser = factory(luajs));
});
} else if(typeof module === "object" && module.exports) {
var luajs = require('./lua');
module.exports = (root.LayerParser = factory(luajs));
}
}(this, function(luajs) {
var LayerParser = {};

var findInitParams = function(ast){
// Find '__init' function
var params;
ast.block.stats.forEach(function(block){
if(block.key && block.key.val == '__init' && block.func){
params = block.func.args;
if(params.length === 0 && block.func.varargs){
params[0] = 'params';
}
}
});
return params;
};

var findTorchClass = function(ast){
var torchClassArgs, // args for `torch.class(...)`
name = '',
baseType,
params = [];

if(ast.type == 'function'){
ast.block.stats.forEach(function(func){
if(func.type == 'stat.local' && func.right && func.right[0] &&
func.right[0].func && func.right[0].func.self &&
func.right[0].func.self.val == 'torch' &&
func.right[0].func.key.val == 'class'){

torchClassArgs = func.right[0].args.map(arg => arg.val);
name = torchClassArgs[0];
if(name !== ''){
name = name.replace('nn.', '');
params = findInitParams(ast);
if (torchClassArgs.length > 1) {
baseType = torchClassArgs[1].replace('nn.', '');
}
}
}
});
// If there is a name, check for methods owned by the given name
// which modify a 'self' value and return 'self'
// TODO
}
return {
name,
baseType,
params
};
};

LayerParser.parse = function(text) {
var ast = luajs.parser.parse(text);
return findTorchClass(ast);
};

return LayerParser;
}));
70 changes: 14 additions & 56 deletions utils/nn-parser.js
Original file line number Diff line number Diff line change
@@ -1,60 +1,20 @@
var fs = require('fs');
var path = require('path');
var parser = require('../src/common/lua').parser;
var torchPath = process.env.HOME + '/torch/extra/nn/';
var SKIP_LAYERS = {};
var skipLayerList = require('./skipLayers.json');
skipLayerList.forEach(name => SKIP_LAYERS[name] = true);

var findInitParams = function(ast){
// Find '__init' function
var params;
ast.block.stats.forEach(function(block){
if(block.key && block.key.val == '__init' && block.func){
params = block.func.args;
if(params.length === 0 && block.func.varargs){
params[0] = 'params';
}
}
});
return params;
};
var fs = require('fs'),
path = require('path'),
torchPath,

var findTorchClass = function(ast){
var torchClassArgs, // args for `torch.class(...)`
name = '',
baseType,
params = [];
LayerParser = require(__dirname + '/../src/common/layer-parser'),
SKIP_LAYERS = {},
skipLayerList = require('./skipLayers.json'),

if(ast.type == 'function'){
ast.block.stats.forEach(function(func){
if(func.type == 'stat.local' && func.right && func.right[0] &&
func.right[0].func && func.right[0].func.self &&
func.right[0].func.self.val == 'torch' &&
func.right[0].func.key.val == 'class'){
categories = require('./categories.json'),
catNames = Object.keys(categories);
layerToCategory = {};

torchClassArgs = func.right[0].args.map(arg => arg.val);
name = torchClassArgs[0];
if(name !== ''){
name = name.replace('nn.', '');
params = findInitParams(ast);
if (torchClassArgs.length > 1) {
baseType = torchClassArgs[1].replace('nn.', '');
}
}
}
});
}
return {
name,
baseType,
params
};
};
// Check the deepforge config
// FIXME
torchPath = process.env.HOME + '/torch/extra/nn/';

var categories = require('./categories.json');
var catNames = Object.keys(categories);
var layerToCategory = {};
skipLayerList.forEach(name => SKIP_LAYERS[name] = true);
catNames.forEach(cat => // create layer -> category dictionary
categories[cat].forEach(lname => layerToCategory[lname] = cat)
);
Expand All @@ -73,14 +33,12 @@ fs.readdir(torchPath, function(err,files){

layers = files.filter(filename => path.extname(filename) === '.lua')
.map(filename => fs.readFileSync(torchPath + filename, 'utf8'))
.map(code => parser.parse(code))
.map(ast => findTorchClass(ast)) // create initial layers
.map(code => LayerParser.parse(code))
.filter(layer => !!layer && layer.name);

layers.forEach(layer => {
layer.type = lookupType(layer.name);
layerByName[layer.name] = layer;
layer.setters = [];
});

// handle inheritance
Expand Down

0 comments on commit 626bcd3

Please sign in to comment.