Skip to content

Commit

Permalink
Added 'getAllDefinitions' utility fn for Export:Pipeline extensions. F…
Browse files Browse the repository at this point in the history
…ixes #959 (#960)

* WIP #959 Added getAllDefinitions utility function

* WIP #959 Updated cli code to use new utility function

* WIP #959 Removed unused import
  • Loading branch information
brollb authored Jan 28, 2017
1 parent 5cedd2c commit 0b43ddf
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 144 deletions.
55 changes: 54 additions & 1 deletion src/plugins/Export/Export.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

define([
'text!./metadata.json',
'text!./deepforge.ejs',
'./format',
'plugin/PluginBase',
'deepforge/plugin/PtrCodeGen',
Expand All @@ -12,6 +13,7 @@ define([
'q'
], function (
pluginMetadata,
DeepForgeBaseCode,
FORMATS,
PluginBase,
PtrCodeGen,
Expand All @@ -28,7 +30,8 @@ define([
lineOffset: true,
code: true
},
RESERVED = /^(and|break|do|else|elseifend|false|for|function|if|in|local|nil|not|orrepeat|return|then|true|until|while|print)$/;
RESERVED = /^(and|break|do|else|elseifend|false|for|function|if|in|local|nil|not|orrepeat|return|then|true|until|while|print)$/,
DeepForgeTpl = _.template(DeepForgeBaseCode);

/**
* Initializes a new instance of Export.
Expand Down Expand Up @@ -842,5 +845,55 @@ define([

_.extend(Export.prototype, PtrCodeGen.prototype);

// Extra utilities for export types
Export.prototype.INIT_CLASSES_FN = '__init_classes';
Export.prototype.INIT_LAYERS_FN = '__init_layers';
Export.prototype.getAllDefinitions = function (sections) {
var code = [],
classes,
initClassFn,
initLayerFn;

classes = sections.orderedClasses
// Create fns from the classes
.map(name => this.indent(sections.classes[name])).join('\n');

initClassFn = [
`local function ${this.INIT_CLASSES_FN}()`,
this.indent(classes),
'end'
].join('\n');

code = code.concat(initClassFn);

// wrap the layers in a function
initLayerFn = [
`local function ${this.INIT_LAYERS_FN}()`,
this.indent(_.values(sections.layers).join('\n\n')),
'end'
].join('\n');
code = code.concat(initLayerFn);

// Add operation fn definitions
code = code.concat(_.values(sections.operations));
code = code.concat(_.values(sections.pipelines));

// define deserializers, serializers
code.push(sections.deserializers);
code.push(sections.serializers);

code.push(this.getDeepforgeObject());
code.push('deepforge.initialize()');

code.push(sections.serializeOutputsDef);
return code.join('\n\n');
};

Export.prototype.getDeepforgeObject = function (content) {
content = content || {};
content.initCode = content.initCode || `${this.INIT_CLASSES_FN}()\n${' '}${this.INIT_LAYERS_FN}()`;
return DeepForgeTpl(content);
};

return Export;
});
183 changes: 40 additions & 143 deletions src/plugins/Export/formats/cli/cli.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
/*globals define*/
// Simple torch cli for the given pipeline
define([
'underscore'
], function(
_
) {

var INIT_CLASSES_FN = '__initClasses',
INIT_LAYERS_FN = '__initLayers',
TOBOOLEAN,
DEEPFORGE_CODE; // defined at the bottom (after the embedded template)
var TOBOOLEAN;

var deserializersFromString = function(sections) {
var hasBool = false;
Expand Down Expand Up @@ -39,156 +34,62 @@ define([
};

var createExecFile = function (sections, staticInputs) {
var classes,
initClassFn,
initLayerFn,
code = [];
var code = [];

// Update deserializers for cli input
deserializersFromString.call(this, sections);

// concat all the sections into a single file
// wrap the class/layer initialization in a fn
// Add the classes ordered wrt their deps
classes = sections.orderedClasses
// Create fns from the classes
.map(name => this.indent(sections.classes[name])).join('\n');
// Define all the operations, pipelines, etc
code.push(this.getAllDefinitions(sections));

initClassFn = [
`local function ${INIT_CLASSES_FN}()`,
this.indent(classes),
'end'
].join('\n');
// Command line specific stuff
var pipelineName = Object.keys(sections.pipelines)[0],
files = {},
main,
args,
staticNames = staticInputs.map(input => input.name),
varDefs,
index = 1;

code = code.concat(initClassFn);
// Create some names for the inputs
args = sections.mainInputNames.map(name => `${sections.deserializerFor[name]}(${name})`);

// wrap the layers in a function
initLayerFn = [
`local function ${INIT_LAYERS_FN}()`,
this.indent(_.values(sections.layers).join('\n\n')),
'end'
].join('\n');
code = code.concat(initLayerFn);
main = `local outputs = ${pipelineName}(${args.join(', ')})`;

// Add operation fn definitions
code = code.concat(_.values(sections.operations));
code = code.concat(_.values(sections.pipelines));
// Grab the args from the cli
code.push(sections.mainInputNames.map((name, index) => {
return `local ${name} = arg[${index + 1}]`;
}).join('\n'));

code.push(DEEPFORGE_CODE);
code.push('deepforge.initialize()');

// define deserializers, serializers
code.push(sections.deserializers);
code.push(sections.serializers);

code.push(sections.serializeOutputsDef);

if (staticInputs.length) {
var files = {},
staticNames = staticInputs.map(input => input.name),
varDefs,
index = 1;

// Add the hash for each of the static inputs and reference them
staticInputs.forEach(input => {
files[`res/${input.name}`] = input.hash;
});

varDefs = staticNames.map(name => {
return `local ${name} = './res/${name}'`;
});

// Grab the remaining args from the cli
varDefs = varDefs.concat(sections.mainInputNames.map(name => {
if (!staticNames.includes(name)) {
return `local ${name} = arg[${index++}]`;
}
}));

// Add the main fn
code.push(varDefs.join('\n'));
code.push(sections.main);

// Save outputs to disk
code.push(sections.serializeOutputs);

files['init.lua'] = code.join('\n\n');

return files;
} else {
var pipelineName = Object.keys(sections.pipelines)[0],
main,
args;
// Add the hash for each of the static inputs and reference them
staticInputs.forEach(input => {
files[`res/${input.name}`] = input.hash;
});

// Create some names for the inputs
args = sections.mainInputNames.map(name => `${sections.deserializerFor[name]}(${name})`);
varDefs = staticNames.map(name => {
return `local ${name} = './res/${name}'`;
});

main = `local outputs = ${pipelineName}(${args.join(', ')})`;
// Grab the remaining args from the cli
varDefs = varDefs.concat(sections.mainInputNames.map(name => {
if (!staticNames.includes(name)) {
return `local ${name} = arg[${index++}]`;
}
}));

// Grab the args from the cli
code.push(sections.mainInputNames.map((name, index) => {
return `local ${name} = arg[${index + 1}]`;
}).join('\n'));
// Add the main fn
code.push(varDefs.join('\n'));
code.push(main);

// Add the main fn
code.push(main);
// Save outputs to disk
code.push(sections.serializeOutputs);

// Save outputs to disk
code.push(sections.serializeOutputs);
files['init.lua'] = code.join('\n\n');

return code.join('\n\n');
}
// if no extra assets, just return the main file
return staticInputs.length ? files : files['init.lua'];
};

var deepforgeTxt =
`-- Instantiate the deepforge object
deepforge = {}
function deepforge.initialize()
require 'nn'
require 'rnn'
<%= initCode %>
end
-- Graph support
torch.class('deepforge.Graph')
function deepforge.Graph:__init(name)
-- nop
end
torch.class('deepforge._Line')
function deepforge._Line:__init(graphId, name, opts)
-- nop
end
function deepforge._Line:add(x, y)
-- nop
end
function deepforge.Graph:line(name, opts)
return deepforge._Line(self.id, name, opts)
end
-- Image support
function deepforge.image(name, tensor)
-- nop
end
torch.class('deepforge.Image')
function deepforge.Image:__init(name, tensor)
-- nop
end
function deepforge.Image:update(tensor)
-- nop
end
function deepforge.Image:title(name)
-- nop
end`;

TOBOOLEAN =
`local function toboolean(str)
if str == 'true' then
Expand All @@ -198,9 +99,5 @@ end`;
end
end`;

DEEPFORGE_CODE = _.template(deepforgeTxt)({
initCode: `${INIT_CLASSES_FN}()\n${' '}${INIT_LAYERS_FN}()`
});

return createExecFile;
});

0 comments on commit 0b43ddf

Please sign in to comment.