Skip to content

Commit

Permalink
Only retrieve member registry from nodes in set. Fixes #998 (#999)
Browse files Browse the repository at this point in the history
* Only retrieve member registry from nodes in set. Fixes #998

* WIP #998 Added test case
  • Loading branch information
brollb authored Mar 31, 2017
1 parent c8e9dfd commit e53b684
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/plugins/GenerateArchitecture/GenerateArchitecture.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ define([
var args = this.createArgString(layer),
def = `nn.${layer.name}${args}`,
type = layer.base.base.name,
addedIds,
memberIds,
node,
name,
children,
Expand All @@ -141,28 +141,32 @@ define([
// each nested architecture's code to the given container
if (type === 'Container') {
// Get the members of the 'addLayers' set
addedIds = {};
memberIds = {};
id = layer[SimpleNodeConstants.NODE_PATH];
node = this._nodeCache[id];
this.core.getMemberPaths(node, Constants.CONTAINED_LAYER_SET)
.forEach(id => addedIds[id] = true);
.forEach(id => memberIds[id] = true);

// Get the (sorted) children
children = layer[SimpleNodeConstants.CHILDREN]
.map(child => { // get (child, index) tuples
var index;
var index = null;

id = child[SimpleNodeConstants.NODE_PATH];
index = this.core.getMemberRegistry(node, Constants.CONTAINED_LAYER_SET, id, Constants.CONTAINED_LAYER_INDEX);
if (memberIds[id]) {
index = this.core.getMemberRegistry(node,
Constants.CONTAINED_LAYER_SET, id, Constants.CONTAINED_LAYER_INDEX);
}
return [child, index];
})
.filter(pair => pair[1] !== undefined) // remove non-members
.filter(pair => pair[1] !== null) // remove non-members
.sort((a, b) => a[1] < b[1] ? -1 : 1) // sort by 'index'
.map(pair => pair[0]);


var addedLayerDefs = '',
firstLayer;

for (var i = 0; i < children.length; i++) {
id = children[i][SimpleNodeConstants.NODE_PATH];
// Get the children!
Expand Down
Binary file modified src/seeds/devTests/devTests.webgmex
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ describe('GenerateArchitecture', function () {
//['/8', 'basic-transfers.lua'],
//['/M', 'concat-parallel.lua'],
['/Q', 'basiccontainer.lua'],
['/P', 'ContainerWithLayerArgs.lua'],
['/4', 'requiredOmitted.lua'],
['/e', 'googlenet.lua'],
['/X', 'overfeat.lua']
Expand Down
13 changes: 13 additions & 0 deletions test/test-cases/generated-code/ContainerWithLayerArgs.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
require 'nn'
require 'rnn'

local m = nn.Sequential()
m:add(nn.Linear(100, 50))
m:add(nn.LeakyReLU())
m:add(nn.Linear(50, 10))


local net = nn.Sequential()
net:add(nn.Bottle(m, 2))

return net

0 comments on commit e53b684

Please sign in to comment.