Skip to content

Commit

Permalink
WIP #541 Added setter support in ImportTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
brollb committed Jul 26, 2016
1 parent a7bb1e9 commit d84764a
Show file tree
Hide file tree
Showing 5 changed files with 1,786 additions and 24 deletions.
37 changes: 16 additions & 21 deletions src/plugins/CreateTorchMeta/CreateTorchMeta.js
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ define([
if (attrs) { // Add the attributes
// Remove attributes not in the given list
var currentAttrs = this.core.getValidAttributeNames(node),
defVal,
rmAttrs;

rmAttrs = _.difference(currentAttrs, attrs) // old attribute names
Expand All @@ -291,15 +292,15 @@ define([
attrs.forEach((name, index) => {
desc = {};
desc.argindex = index;
desc.default = defaults.hasOwnProperty(name) ? defaults[name] : '';
this.addAttribute(name, node, desc);
defVal = defaults.hasOwnProperty(name) ? defaults[name] : '';
this.addAttribute(name, node, desc, defVal);
});

// Add the setters to the meta
Object.keys(setters).forEach(name => {
var values;
desc = setters[name];
desc.default = defaults.hasOwnProperty(name) ? defaults[name] : '';
defVal = defaults.hasOwnProperty(name) ? defaults[name] : '';
if (desc.setterType === 'const') {
values = Object.keys(desc.setterFn);
desc.isEnum = true;
Expand All @@ -308,16 +309,16 @@ define([
if (!defaults.hasOwnProperty(name) && values.length === 1) {
// there is only a method to toggle the flag to true/false,
// then the default must be the other one
desc.default = values[0] === 'true' ? false : true;
defVal = values[0] === 'true' ? false : true;
}

if (isBoolean(desc.default)) {
if (isBoolean(defVal)) {
this.logger.debug(`setting ${name} to boolean`);
desc.type = 'boolean';
}
}
}
this.addAttribute(name, node, desc);
this.addAttribute(name, node, desc, defVal);
});
}
this.logger.debug(`added ${name} to the meta`);
Expand Down Expand Up @@ -352,36 +353,30 @@ define([
};
};

CreateTorchMeta.prototype.addAttribute = function (name, node, def) {
var initial,
schema = {};

schema.type = def.type || 'string';
CreateTorchMeta.prototype.addAttribute = function (name, node, schema, defVal) {
schema.type = schema.type || 'string';
if (schema.type === 'list') { // FIXME: add support for lists
schema.type = 'string';
}

if (def.min !== undefined) {
schema.min = +def.min;
if (schema.min !== undefined) {
schema.min = +schema.min;
}

if (def.max !== undefined) {
if (schema.max !== undefined) {
// Set the min, max
schema.max = +def.max;
schema.max = +schema.max;
}

// Add the argindex flag
schema.argindex = def.argindex;
schema.argindex = schema.argindex;

// Create the attribute and set the schema
this.core.setAttributeMeta(node, name, schema);

// Determine a default value
initial = def.hasOwnProperty('default') ? def.default : def.min || null;
if (schema.type === 'boolean') {
initial = initial !== null ? initial : false;
if (defVal) {
this.core.setAttribute(node, name, defVal);
}
this.core.setAttribute(node, name, initial);
};

return CreateTorchMeta;
Expand Down
6 changes: 3 additions & 3 deletions src/plugins/ImportTorch/nn.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ define([

Layer.prototype._setAttribute = function(name, self, value) {
var node = this._node();
debugger;
logger.info(`Setting ${name} to ${value}`);
core.setAttribute(node, name, value);
return self;
};

// Each container will have `inputs` and `outputs`
Expand Down Expand Up @@ -191,7 +191,7 @@ define([
for (var i = vals.length; i--;) {
fn = desc.setterFn[vals[i]];
value = getValue(vals[i]);
table.set(fn, layer._setAttribute.bind(layer, attr, value));
table.set(fn, layer._setAttribute.bind(layer, attr, table, value));
}
}
};
Expand All @@ -200,7 +200,7 @@ define([
var res = luajs.newContext()._G,
attrs = [].slice.call(arguments, 1),
ltGet = luajs.types.LuaTable.prototype.get,
setters = {},
setters = [],
args = [],
node;

Expand Down
1 change: 1 addition & 0 deletions test/plugins/ImportTorch/ImportTorch.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var testFixture = require('../../globals'),
'basic4.lua'
],
ONLY_TESTS = [
'googlenet-setters.lua',
'vgg.lua'
];

Expand Down
1 change: 1 addition & 0 deletions test/test-cases/code/googlenet-setters.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- Copy of googlenet.lua which uses setters (the other googlenet has them removed)
require 'nn'
nGPU = 10
nClasses = 1000
local function inception(input_size, config)
local concat = nn.Concat(2)
if config[1][1] ~= 0 then
Expand Down
Loading

0 comments on commit d84764a

Please sign in to comment.