Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a meaningful error when retrieving a nil value from a layer. Fixes #386 #442

Merged
merged 1 commit into from
Jul 1, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/lua.js
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,7 @@ function LuaContext(){
}
exports.stdlib(_G, helpers)();
}
this.__helpers = helpers;
}

LuaContext.prototype = {}
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/ImportTorch/ImportTorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ define([
ImportTorch.prototype.loadNNMock = function () {
// This needs a refactor...
// createNN(this)
var lib = createNNSearcher(this).bind(this.context);
var lib = createNNSearcher(this, this.context).bind(this.context);

// Create a "searcher" to allow this 'nn' to be in the lib path
this.context._G.get('package').set('searchers', [function(name) {
Expand Down
26 changes: 24 additions & 2 deletions src/plugins/ImportTorch/nn.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,23 @@ define([
) {
'use strict';

var createSearcher = function(plugin) {
var createSearcher = function(plugin, context) {
var core = plugin.core,
META = plugin.META,
logger = plugin.logger.fork('nn'),
parent = plugin.tgtNode,
LayerDict = createLayerDict(core, META);
LayerDict = createLayerDict(core, META),
helpers = context.__helpers,
oldSet = helpers.__set,
isSetting = false;

// Override the helper's '__set' method to detect
// if the code is in the middle of a "set".
helpers.__set = function() {
isSetting = true;
oldSet.apply(this, arguments);
isSetting = false;
};

var connect = function(src, dst) {
var conn = core.createNode({
Expand Down Expand Up @@ -145,6 +156,7 @@ define([
var CreateLayer = function(type) {
var res = luajs.newContext()._G,
attrs = [].slice.call(arguments, 1),
ltGet = luajs.types.LuaTable.prototype.get,
node;

if (LAYERS[type]) {
Expand All @@ -165,6 +177,16 @@ define([
}
}
}

// Override get
res.get = function noNilGet(value) {
var result = ltGet.call(this, value);
if (!result && !isSetting) {
throw Error(`"${value}" is not supported for ${type}`);
}
return result;
};

return res;
};

Expand Down
83 changes: 83 additions & 0 deletions test/test-cases/code/googlenet-setters.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
-- Copy of googlenet.lua which uses setters (the other googlenet has them removed)
-- Test fixture for the ImportTorch plugin: unlike the stripped googlenet.lua,
-- this copy keeps chained setter calls (e.g. :ceil(), :setNumInputDims(3)) so the
-- importer's overridden table 'get' does not raise "not supported" during a set.
require 'nn'
-- NOTE(review): nGPU is assigned but never read in this file, and nClasses
-- (used in the Linear layers below) is never defined here -- presumably both
-- are supplied by the import harness. Confirm against the test runner.
nGPU = 10
-- Build one GoogLeNet inception block: an nn.Concat (over dimension 2) of four
-- parallel branches -- 1x1 conv, 3x3 conv tower, double-3x3 conv tower, and a
-- pooling branch. `config` is a 4-entry table of per-branch channel counts;
-- branch 1 and the pooling projection are skipped when their count is 0.
local function inception(input_size, config)
local concat = nn.Concat(2)
if config[1][1] ~= 0 then
local conv1 = nn.Sequential()
conv1:add(nn.SpatialConvolution(input_size, config[1][1],1,1,1,1)):add(nn.ReLU(true))
concat:add(conv1)
end

-- 1x1 reduction followed by a 3x3 convolution.
local conv3 = nn.Sequential()
conv3:add(nn.SpatialConvolution( input_size, config[2][1],1,1,1,1)):add(nn.ReLU(true))
conv3:add(nn.SpatialConvolution(config[2][1], config[2][2],3,3,1,1,1,1)):add(nn.ReLU(true))
concat:add(conv3)

-- 1x1 reduction followed by two stacked 3x3 convolutions (5x5-equivalent field).
local conv3xx = nn.Sequential()
conv3xx:add(nn.SpatialConvolution( input_size, config[3][1],1,1,1,1)):add(nn.ReLU(true))
conv3xx:add(nn.SpatialConvolution(config[3][1], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true))
conv3xx:add(nn.SpatialConvolution(config[3][2], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true))
concat:add(conv3xx)

-- Pooling branch: config[4][1] selects 'max' or 'avg'; anything else is an error.
local pool = nn.Sequential()
pool:add(nn.SpatialZeroPadding(1,1,1,1)) -- remove after getting nn R2 into fbcode
if config[4][1] == 'max' then
pool:add(nn.SpatialMaxPooling(3,3,1,1):ceil())
elseif config[4][1] == 'avg' then
pool:add(nn.SpatialAveragePooling(3,3,1,1):ceil())
else
error('Unknown pooling')
end
if config[4][2] ~= 0 then
pool:add(nn.SpatialConvolution(input_size, config[4][2],1,1,1,1)):add(nn.ReLU(true))
end
concat:add(pool)

return concat
end

-- Stem plus the 3(x)/4(a-d) inception stages shared by both output branches.
local features = nn.Sequential()
features:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3)):add(nn.ReLU(true))
features:add(nn.SpatialMaxPooling(3,3,2,2):ceil())
features:add(nn.SpatialConvolution(64,64,1,1)):add(nn.ReLU(true))
features:add(nn.SpatialConvolution(64,192,3,3,1,1,1,1)):add(nn.ReLU(true))
features:add(nn.SpatialMaxPooling(3,3,2,2):ceil())
features:add(inception( 192, {{ 64},{ 64, 64},{ 64, 96},{'avg', 32}})) -- 3(a)
features:add(inception( 256, {{ 64},{ 64, 96},{ 64, 96},{'avg', 64}})) -- 3(b)
features:add(inception( 320, {{ 0},{128,160},{ 64, 96},{'max', 0}})) -- 3(c)
features:add(nn.SpatialConvolution(576,576,2,2,2,2))
features:add(inception( 576, {{224},{ 64, 96},{ 96,128},{'avg',128}})) -- 4(a)
features:add(inception( 576, {{192},{ 96,128},{ 96,128},{'avg',128}})) -- 4(b)
features:add(inception( 576, {{160},{128,160},{128,160},{'avg', 96}})) -- 4(c)
features:add(inception( 576, {{ 96},{128,192},{160,192},{'avg', 96}})) -- 4(d)

-- Main classifier branch: stages 4(e)/5(a)/5(b) then pool -> view -> linear -> log-softmax.
local main_branch = nn.Sequential()
main_branch:add(inception( 576, {{ 0},{128,192},{192,256},{'max', 0}})) -- 4(e)
main_branch:add(nn.SpatialConvolution(1024,1024,2,2,2,2))
main_branch:add(inception(1024, {{352},{192,320},{160,224},{'avg',128}})) -- 5(a)
main_branch:add(inception(1024, {{352},{192,320},{192,224},{'max',128}})) -- 5(b)
main_branch:add(nn.SpatialAveragePooling(7,7,1,1))
main_branch:add(nn.View(1024):setNumInputDims(3))
main_branch:add(nn.Linear(1024,nClasses))
main_branch:add(nn.LogSoftMax())

-- add auxiliary classifier here (thanks to Christian Szegedy for the details)
local aux_classifier = nn.Sequential()
aux_classifier:add(nn.SpatialAveragePooling(5,5,3,3):ceil())
aux_classifier:add(nn.SpatialConvolution(576,128,1,1,1,1))
aux_classifier:add(nn.View(128*4*4):setNumInputDims(3))
aux_classifier:add(nn.Linear(128*4*4,768))
aux_classifier:add(nn.ReLU())
aux_classifier:add(nn.Linear(768,nClasses))
aux_classifier:add(nn.LogSoftMax())

-- Fork the shared features into the main and auxiliary classifiers.
local splitter = nn.Concat(2)
splitter:add(main_branch):add(aux_classifier)
local model = nn.Sequential():add(features):add(splitter)

-- Metadata consumed by the training/import harness, not by the network itself.
model.imageSize = 256
model.imageCrop = 224


return model