-
Notifications
You must be signed in to change notification settings - Fork 0
/
NarrowTable.lua
41 lines (36 loc) · 1.12 KB
/
NarrowTable.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
local NarrowTable, parent = torch.class('nn.NarrowTable', 'nn.Module')
function NarrowTable:__init(offset, length)
parent.__init(self)
self.offset = offset
self.length = length or 1
if not offset then
error('nn.NarrowTable(offset, length)')
end
self.output = {}
self.gradInput = {}
end
function NarrowTable:updateOutput(input)
for k,v in ipairs(self.output) do self.output[k] = nil end
for i=1,self.length do
self.output[i] = input[self.offset+i-1]
end
return self.output
end
function NarrowTable:updateGradInput(input, gradOutput)
for i=1,#gradOutput do
self.gradInput[self.offset+i-1] = gradOutput[i]
end
for i=1,#input do
if (i < self.offset) or (i >= self.offset + self.length) then
self.gradInput[i] = nn.utils.recursiveResizeAs(self.gradInput[i], input[i])
nn.utils.recursiveFill(self.gradInput[i], 0)
end
end
for i=#input+1,#self.gradInput do self.gradInput[i] = nil end
return self.gradInput
end
function NarrowTable:type(type, tensorCache)
self.output = {}
self.gradInput = {}
return parent.type(self, type, tensorCache)
end