DistillationCriterion.lua
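-- DistillationCriterion
-- Binary cross-entropy against soft (real-valued) targets, intended for
-- knowledge distillation: the target is typically a teacher model's
-- predicted probability rather than a hard 0/1 label.
-- Expects input values in (0, 1), e.g. the output of nn.Sigmoid().
-- An optional 1-D `weights` tensor rescales each class's contribution;
-- `eps` guards the logarithms against log(0).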
local DistillationCriterion, parent = torch.class('nn.DistillationCriterion', 'nn.Criterion')
local eps = 1e-12

function DistillationCriterion:__init(weights, sizeAverage)
   parent.__init(self)
   if sizeAverage ~= nil then
      self.sizeAverage = sizeAverage
   else
      self.sizeAverage = true
   end
   if weights ~= nil then
      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
      self.weights = weights
   end
end
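-- For example (hypothetical per-class weights for a 2-class output):
--   local crit = nn.DistillationCriterion(torch.Tensor{0.3, 0.7})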

function DistillationCriterion:__len()
   if (self.weights) then
      return #self.weights
   else
      return 0
   end
end

function DistillationCriterion:updateOutput(input, target)
   -- - log(input) * target - log(1 - input) * (1 - target)
   assert(input:nElement() == target:nElement(),
      "input and target size mismatch")

   self.buffer = self.buffer or input.new()

   local buffer = self.buffer
   local weights = self.weights
   local output

   buffer:resizeAs(input)

   if weights ~= nil and target:dim() ~= 1 then
      weights = self.weights:view(1, target:size(2)):expandAs(target)
   end

   -- log(input) * target
   buffer:add(input, eps):log()
   if weights ~= nil then buffer:cmul(weights) end

   output = torch.dot(target, buffer)

   -- log(1 - input) * (1 - target)
   buffer:mul(input, -1):add(1):add(eps):log()
   if weights ~= nil then buffer:cmul(weights) end

   output = output + torch.sum(buffer)
   output = output - torch.dot(target, buffer)

   if self.sizeAverage then
      output = output / input:nElement()
   end

   self.output = - output

   return self.output
end
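-- Worked example of the forward pass (hypothetical numbers, ignoring eps,
-- with the default sizeAverage = true):
--   input = {0.9}, target = {0.5}
--   loss  = -( 0.5 * log(0.9) + 0.5 * log(0.1) )
--         ≈ -( 0.5 * (-0.105) + 0.5 * (-2.303) )
--         ≈ 1.204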

function DistillationCriterion:updateGradInput(input, target)
   -- - (target - input) / ( input (1 - input) )
   -- The gradient is slightly incorrect:
   -- it should be divided by (input + eps) (1 - input + eps),
   -- but it is divided by input (1 - input + eps) + eps.
   -- This modification requires less memory to compute.
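   -- For reference, the exact derivation (eps omitted):
   --   l(x, y) = -( y log(x) + (1 - y) log(1 - x) )
   --   dl/dx   = -( y/x - (1 - y)/(1 - x) )
   --           = -(y - x) / ( x (1 - x) )
   -- The code below folds eps into the denominator so the whole gradient
   -- is computed with a single extra buffer.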
   assert(input:nElement() == target:nElement(),
      "input and target size mismatch")

   self.buffer = self.buffer or input.new()

   local buffer = self.buffer
   local weights = self.weights
   local gradInput = self.gradInput

   if weights ~= nil and target:dim() ~= 1 then
      weights = self.weights:view(1, target:size(2)):expandAs(target)
   end

   buffer:resizeAs(input)
   -- - ( x (1 + eps - x) + eps )
   buffer:add(input, -1):add(-eps):cmul(input):add(-eps)

   gradInput:resizeAs(input)

   -- y - x
   gradInput:add(target, -1, input)

   -- - (y - x) / ( x (1 + eps - x) + eps )
   gradInput:cdiv(buffer)

   if weights ~= nil then
      gradInput:cmul(weights)
   end

   if self.sizeAverage then
      gradInput:div(target:nElement())
   end

   return gradInput
end
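
-- Minimal usage sketch (the tensors and the student/teacher names below are
-- illustrative assumptions, not part of this module):
--
--   require 'nn'
--   local crit = nn.DistillationCriterion()
--   local studentProbs = torch.Tensor{0.9, 0.2, 0.6} -- e.g. nn.Sigmoid() outputs
--   local teacherProbs = torch.Tensor{0.8, 0.1, 0.7} -- soft targets from a teacher
--   local loss = crit:forward(studentProbs, teacherProbs)
--   local grad = crit:backward(studentProbs, teacherProbs)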