-
Notifications
You must be signed in to change notification settings - Fork 966
/
Constant.lua
36 lines (33 loc) · 1.36 KB
/
Constant.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
------------------------------------------------------------------------
--[[ Constant ]]--
-- Outputs a constant value, regardless of the values of its input.
-- The constant is given as a number (wrapped into a one-element tensor)
-- or as a tensor.
-- If nInputDim is specified and the input has more than nInputDim
-- dimensions, the input's first dimension is used as the batch size and
-- the constant is replicated over a new leading batch dimension.
-- The gradient with respect to the input is always zero.
-- You can use this with nn.ConcatTable() to append a constant to an
-- input: nn.ConcatTable():add(nn.Constant(v)):add(nn.Identity())
------------------------------------------------------------------------
-- `parent` is used by the constructor to invoke nn.Module's __init.
local Constant, parent = torch.class("nn.Constant", "nn.Module")
--- Build a module that always emits `value`.
-- @param value number or tensor to emit; a plain number is promoted to a
--   one-element tensor of the default tensor type.
-- @param nInputDim optional; when set and the forward input has more
--   dimensions than this, the constant is replicated over the input's
--   first dimension.
function Constant:__init(value, nInputDim)
   local constant = value
   -- Promote a bare number so the rest of the module only handles tensors.
   if torch.type(constant) == 'number' then
      constant = torch.Tensor{constant}
   end
   self.value = constant
   assert(torch.isTensor(self.value), "Expecting number or tensor at arg 1")
   self.nInputDim = nInputDim
   -- nn.Module's constructor runs last, after the constant is stored.
   parent.__init(self)
end
--- Produce the constant, optionally replicated over the input's batch.
-- When `nInputDim` was given and `input` has more dimensions than that,
-- `input:size(1)` is taken as the batch size and the output has shape
-- `batchSize x value:size()`. Otherwise the output is a copy of `value`.
-- @param input when `nInputDim` is set this must be a tensor (its `dim()`
--   is read); otherwise it is not inspected at all.
-- @return self.output
function Constant:updateOutput(input)
   if self.nInputDim and input:dim() > self.nInputDim then
      local vsize = self.value:size():totable()
      self.output:resize(input:size(1), table.unpack(vsize))
      -- view() only works on contiguous tensors. contiguous() returns the
      -- tensor itself when it already is contiguous, so a copy is only made
      -- when the constant was supplied as e.g. a transposed tensor.
      local value = self.value:contiguous():view(1, table.unpack(vsize))
      -- expand() broadcasts the size-1 leading dimension over the batch
      -- without allocating; copy() materialises it into self.output.
      self.output:copy(value:expand(self.output:size()))
   else
      self.output:resize(self.value:size()):copy(self.value)
   end
   return self.output
end
--- Gradient with respect to the input: always zero, since the output does
-- not depend on the input's values.
-- Without nInputDim, updateOutput never reads `input`, so forward accepts
-- any input, including a (nested) table of tensors. The recursive nn.utils
-- helpers let backward accept the same inputs; for a plain tensor they
-- behave like resizeAs(input):zero().
-- @param input tensor or (nested) table of tensors
-- @param gradOutput unused
-- @return self.gradInput, zero-filled and shaped like `input`
function Constant:updateGradInput(input, gradOutput)
   -- recursiveResizeAs also replaces gradInput when its structure no longer
   -- matches input (e.g. tensor -> table between calls).
   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
   nn.utils.recursiveFill(self.gradInput, 0)
   return self.gradInput
end