Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added concat support to GenArch. Fixes #61 #273

Merged
merged 1 commit into from
Jun 13, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
"start-dev": "NODE_ENV=dev node app.js",
"worker": "node ./bin/start-worker.js",
"local": "node ./bin/start-local.js",
"test": "node ./node_modules/mocha/bin/mocha --recursive test"
"test": "node ./node_modules/mocha/bin/mocha --recursive test",
"watch-test": "./node_modules/nodemon/bin/nodemon.js --exec 'node ./node_modules/mocha/bin/mocha --recursive test'"
},
"version": "0.4.0",
"dependencies": {
"dotenv": "^2.0.0",
"lodash.difference": "^4.1.2",
"nodemon": "^1.9.2",
"webgme": "^2.0.0",
"webgme-autoviz": "^2.0.3",
"webgme-breadcrumbheader": "^2.0.0",
Expand Down
124 changes: 102 additions & 22 deletions src/plugins/GenerateArchitecture/GenerateArchitecture.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ define([
* @classdesc This class represents the plugin GenerateArchitecture.
* @constructor
*/
var INDEX = '__index__';
var GenerateArchitecture = function () {
// Call base class' constructor.
PluginBase.call(this);
Expand All @@ -43,44 +44,123 @@ define([

GenerateArchitecture.prototype.main = function () {
this.LayerDict = createLayerDict(this.core, this.META);
this.uniqueId = 2;
this._oldTemplateSettings = _.templateSettings;
return PluginBase.prototype.main.apply(this, arguments);
};

GenerateArchitecture.prototype.createOutputFiles = function (tree) {
var layers = tree[Constants.CHILDREN],
//initialLayers,
result = {},
template,
snippet,
code,
args;
code;

//initialLayers = layers.filter(layer => layer[Constants.PREV].length === 0);
// Add an index to each layer
layers.forEach((l, index) => l[INDEX] = index);
code = this.genArchCode(layers);

result[tree.name + '.lua'] = code;
_.templateSettings = this._oldTemplateSettings; // FIXME: Fix this in SimpleNodes
return result;
};

code = [
GenerateArchitecture.prototype.genArchCode = function (layers) {
// Create a 'null' start layer

return [
'require \'nn\'',
'',
'local net = nn.Sequential()'
this.createSequential(layers[0], 'net').code,
'\nreturn net'
].join('\n');
};

// Start with sequential (just one input)
for (var i = 0; i < layers.length; i++) {
if (layers[i][Constants.NEXT].length > 1) {
// no support for
this.logger.error('No support for parallel layers... yet');
GenerateArchitecture.prototype.createSequential = function (layer, name) {
var next = layer[Constants.NEXT][0],
args,
template,
snippet,
snippets,
code = `\nlocal ${name} = nn.Sequential()`,

group,
i,
result;

while (layer) {
// if there is only one successor, just add the given layer
if (layer[Constants.PREV].length > 1) { // sequential layers are over
next = layer; // the given layer will be added by the caller
break;
} else {
// args
args = this.createArgString(layers[i]);
template = _.template('net:add(nn.{{= name }}' + args + ')');
snippet = template(layers[i]);
} else { // add the given layer
args = this.createArgString(layer);
template = _.template(name + ':add(nn.{{= name }}' + args + ')');
snippet = template(layer);
code += '\n' + snippet;

}
}

code += '\n\nreturn net';
while (layer && layer[Constants.NEXT].length > 1) { // concat/parallel
// if there is a fork, recurse and add a concat layer

result[tree.name + '.lua'] = code;
_.templateSettings = this._oldTemplateSettings; // FIXME: Fix this in SimpleNodes
return result;
this.logger.debug(`detected fork of size ${layer[Constants.NEXT].length}`);
snippets = layer[Constants.NEXT].map(nlayer =>
this.createSequential(nlayer, 'net_'+(this.uniqueId++)));
code += '\n' + snippets.map(snippet => snippet.code).join('\n');

// Make sure all snippets end at the same concat node

// Until all snippets end at the same concat node
snippets.sort((a, b) => a.endIndex < b.endIndex ? -1 : 1);
group = [];
while (snippets.length > 0) {
// Add snippets to the group
i = 0;
while (i < snippets.length &&
snippets[0].endIndex === snippets[i].endIndex) {

group.push(snippets[i]);
i++;
}

// Add concat layer
layer = group[0].next;
if (layer) {
args = this.createArgString(layer);
code += `\n\nlocal concat_${layer[INDEX]} = nn.Concat${args}\n` +
group.map(snippet =>
`concat_${layer[INDEX]}:add(${snippet.name})`)
.join('\n') + `\n\n${name}:add(concat_${layer[INDEX]})`;

next = layer[Constants.NEXT][0];
} else {
next = null; // no next layers
}

// Remove the updated snippets
this.logger.debug('removing ' + i + ' snippet(s)');
snippets.splice(0, i);

// merge the elements in the group
if (snippets.length) { // prepare next iteration
result = this.createSequential(next, 'net_'+(this.uniqueId++));
code += result.code;
group = [result];
this.logger.debug('updating group ('+ snippets.length+ ' left)');
}
}
}

layer = next;
next = layer && layer[Constants.NEXT][0];
}

return {
code: code,
name: name,
endIndex: next ? next[INDEX] : Infinity,
next: next
};
};

GenerateArchitecture.prototype.createArgString = function (layer) {
Expand Down
4 changes: 2 additions & 2 deletions src/seeds/devTests/devTests.webgmex
Git LFS file not shown
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ describe('GenerateArchitecture', function () {
var cases = [
['/4', 'basic.lua'],
['/T', 'basic-transfers.lua'],
['/t', 'concat-parallel.lua'],
['/w', 'googlenet.lua'],
['/W', 'overfeat.lua']
// TODO: Add more tests
// Need a concat test
// TODO
];

var runTest = function(pair, done) {
Expand Down
4 changes: 2 additions & 2 deletions test/plugins/GenerateExecFile/GenerateExecFile.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'use strict';
var testFixture = require('../../globals');

describe('GenerateExecFile', function () {
describe.skip('GenerateExecFile', function () {
var gmeConfig = testFixture.getGmeConfig(),
expect = testFixture.expect,
logger = testFixture.logger.fork('GenerateExecFile'),
Expand All @@ -28,7 +28,7 @@ describe('GenerateExecFile', function () {
})
.then(function () {
var importParam = {
projectSeed: testFixture.path.join(testFixture.SEED_DIR, 'EmptyProject.json'),
projectSeed: testFixture.path.join(testFixture.SEED_DIR, 'EmptyProject.webgmex'),
projectName: projectName,
branchName: 'master',
logger: logger,
Expand Down
4 changes: 2 additions & 2 deletions test/plugins/ImportArtifact/ImportArtifact.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'use strict';
var testFixture = require('../../globals');

describe('ImportArtifact', function () {
describe.skip('ImportArtifact', function () {
var gmeConfig = testFixture.getGmeConfig(),
expect = testFixture.expect,
logger = testFixture.logger.fork('ImportArtifact'),
Expand All @@ -28,7 +28,7 @@ describe('ImportArtifact', function () {
})
.then(function () {
var importParam = {
projectSeed: testFixture.path.join(testFixture.SEED_DIR, 'EmptyProject.json'),
projectSeed: testFixture.path.join(testFixture.SEED_DIR, 'EmptyProject.webgmex'),
projectName: projectName,
branchName: 'master',
logger: logger,
Expand Down
24 changes: 24 additions & 0 deletions test/test-cases/generated-code/concat-parallel.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
require 'nn'

local net = nn.Sequential()
net:add(nn.Reshape(100))

local net_2 = nn.Sequential()
net_2:add(nn.Linear(100, 150))
net_2:add(nn.Tanh())
net_2:add(nn.Linear(150, 50))

local net_3 = nn.Sequential()
net_3:add(nn.Linear(100, 150))
net_3:add(nn.Tanh())
net_3:add(nn.Linear(150, 30))

local concat_7 = nn.Concat(1)
concat_7:add(net_3)
concat_7:add(net_2)

net:add(concat_7)
net:add(nn.Tanh())
net:add(nn.Linear(80, 7))

return net
Loading