-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
328 lines (291 loc) · 14.4 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# Copyright (c) Hangzhou Hikvision Digital Technology Co., Ltd. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation of the knowledge distillation losses.
Distill the pruned model using the unpruned model and a CNN as teachers.
"""
import torch
from torch.nn import functional as F
from functools import partial
class DistillationLoss(torch.nn.Module):
    """
    Blend a standard criterion with a knowledge-distillation term.

    The student's distillation head is supervised by the predictions of a
    teacher model (run without gradients), and the two terms are mixed as
    ``(1 - alpha) * base_loss + alpha * distillation_loss``.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ('none', 'soft', 'hard')
        self.distillation_type = distillation_type
        # alpha: weight of the distillation term; tau: temperature used by 'soft'.
        self.alpha, self.tau = alpha, tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: the original inputs, forwarded unchanged to the teacher model.
            outputs: the student's output; either a Tensor, or a pair
                (class-token output, distillation-token output).
            labels: the targets for the base criterion.

        Returns:
            The base loss alone when distillation_type == 'none', otherwise the
            alpha-weighted blend of base and distillation losses.

        Raises:
            ValueError: distillation is enabled but ``outputs`` is a single Tensor.
        """
        if isinstance(outputs, torch.Tensor):
            student_kd = None
        else:
            # Expected layout: (class-token output, distillation-token output).
            outputs, student_kd = outputs

        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if student_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")

        # The teacher only provides targets; never backpropagate into it.
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            temperature = self.tau
            student_log_p = F.log_softmax(student_kd / temperature, dim=1)
            teacher_log_p = F.log_softmax(teacher_logits / temperature, dim=1)
            # Teacher is given as log-probabilities, hence log_target=True.
            kl = F.kl_div(student_log_p, teacher_log_p, reduction='sum', log_target=True)
            kd_loss = kl * (temperature * temperature) / student_kd.numel()
        elif self.distillation_type == 'hard':
            kd_loss = F.cross_entropy(student_kd, teacher_logits.argmax(dim=1))

        return base_loss * (1 - self.alpha) + kd_loss * self.alpha
def soft_cross_entropy(predicts, targets):
    """
    Cross-entropy between two logit tensors, using softmax(targets) as soft labels.

    Both inputs are normalized over their last dimension. The per-row
    cross-entropy is summed over that dimension and averaged over all remaining
    positions.
    """
    log_p_student = F.log_softmax(predicts, dim=-1)
    p_teacher = F.softmax(targets, dim=-1)
    per_row = (-p_teacher * log_p_student).sum(dim=-1)
    return per_row.mean()
def cal_relation_loss(student_attn_list, teacher_attn_list, Ar=1):
    '''
    Distill the relations between the first three tensors of each layer (MiniViT-style).

    For every layer, each of the first three tensors of the student entry (and of
    the teacher entry) is split into ``Ar`` heads. For every ordered pair (i, j)
    the relation ``X_i @ X_j^T / sqrt(C / Ar)`` is formed. The student relation is
    matched to the teacher relation with ``soft_cross_entropy`` over the last axis.

    Args:
        student_attn_list: per-layer sequences whose first three entries are
            (B, N, Cs) tensors (presumably q, k, v — TODO confirm against the model).
        teacher_attn_list: per-layer sequences whose first three entries are
            (B, N, Ct) tensors with the same B and N as the student.
        Ar: number of relation heads; Cs and Ct must be divisible by it.

    Returns:
        The relation loss averaged over the 9 (i, j) pairs and over the layers.
    '''
    layer_num = len(student_attn_list)
    relation_loss = 0.
    for student_att, teacher_att in zip(student_attn_list, teacher_attn_list):
        B, N, Cs = student_att[0].shape
        _, _, Ct = teacher_att[0].shape
        # Split every tensor into heads once: (B, N, C) -> (B, Ar, N, C // Ar).
        # reshape (instead of view) also accepts non-contiguous inputs.
        s_heads = [student_att[k].reshape(B, N, Ar, Cs // Ar).transpose(1, 2) for k in range(3)]
        t_heads = [teacher_att[k].reshape(B, N, Ar, Ct // Ar).transpose(1, 2) for k in range(3)]
        s_scale = (Cs / Ar) ** 0.5
        t_scale = (Ct / Ar) ** 0.5
        for i in range(3):
            s_rows = s_heads[i] / s_scale
            t_rows = t_heads[i] / t_scale
            for j in range(3):
                # (B, Ar, N, C // Ar) @ (B, Ar, C // Ar, N) -> (B, Ar, N, N)
                As_ij = s_rows @ s_heads[j].transpose(-1, -2)
                At_ij = t_rows @ t_heads[j].transpose(-1, -2)
                relation_loss += soft_cross_entropy(As_ij, At_ij)
    return relation_loss / (9. * layer_num)
def cal_hidden_loss(student_hidden_list, teacher_hidden_list):
    '''
    Distill MLP features.

    Computes the mean-squared error between each paired student/teacher feature
    and averages it over the number of student layers.
    '''
    layer_count = len(student_hidden_list)
    total = sum(
        (F.mse_loss(student, teacher)
         for student, teacher in zip(student_hidden_list, teacher_hidden_list)),
        0.,
    )
    return total / layer_count
def cal_hidden_relation_loss(student_hidden_list, teacher_hidden_list, window_area=49):
    '''
    Distill the relations between MLP features.

    Each feature is L2-normalized along its last dimension. Its token-to-token
    cosine-similarity matrix (``X @ X^T``) is then compared between student and
    teacher with a mean-squared error.

    Args:
        student_hidden_list: per-layer student features (token axis second-to-last).
        teacher_hidden_list: per-layer teacher features, paired with the student ones.
        window_area: multiplier applied to each layer's mean error; the original
            code hard-coded 49 (window size x window size).

    Returns:
        The scaled error averaged over the number of student layers.
    '''
    layer_num = len(student_hidden_list)
    hidden_loss = 0.
    for student_hidden, teacher_hidden in zip(student_hidden_list, teacher_hidden_list):
        student_hidden = F.normalize(student_hidden, dim=-1)
        teacher_hidden = F.normalize(teacher_hidden, dim=-1)
        student_relation = student_hidden @ student_hidden.transpose(-1, -2)
        teacher_relation = teacher_hidden @ teacher_hidden.transpose(-1, -2)
        hidden_loss += torch.mean((student_relation - teacher_relation) ** 2) * window_area
    return hidden_loss / layer_num
def soft_distillation(outputs, teacher_outputs, outputs_kd, teacher_outputs_kd, T=1):
    '''
    Soft distillation on both student heads.

    For each (student, teacher) logit pair, computes the summed KL divergence
    between temperature-T softmaxes. The result is multiplied by T*T and divided
    by the number of student logit elements. The class-head and
    distillation-head terms are added.
    '''
    def _tempered_kl(student_logits, teacher_logits):
        # The teacher is passed as log-probabilities, hence log_target=True.
        kl = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.log_softmax(teacher_logits / T, dim=1),
            reduction='sum',
            log_target=True,
        )
        return kl * (T * T) / student_logits.numel()

    cls_term = _tempered_kl(outputs, teacher_outputs)
    dist_term = _tempered_kl(outputs_kd, teacher_outputs_kd)
    return cls_term + dist_term
def hard_distillation(outputs, teacher_outputs, outputs_kd, teacher_outputs_kd):
    '''
    Hard distillation on both student heads.

    Each student head is trained with cross-entropy against the teacher's argmax
    class (over dim 1). The class-head and distillation-head terms are added.
    '''
    head_pairs = ((outputs, teacher_outputs), (outputs_kd, teacher_outputs_kd))
    cls_term, dist_term = (
        F.cross_entropy(student, teacher.argmax(dim=1)) for student, teacher in head_pairs
    )
    return cls_term + dist_term
def mse(outputs, teacher_outputs, outputs_kd, teacher_outputs_kd):
    '''
    MSE distillation on both student heads.

    Returns the mean-squared error of the class head against the teacher's
    class head, plus the same error for the distillation heads.
    '''
    head_losses = [
        F.mse_loss(student, teacher)
        for student, teacher in ((outputs, teacher_outputs), (outputs_kd, teacher_outputs_kd))
    ]
    return head_losses[0] + head_losses[1]
class PrunedDistillationLoss(torch.nn.Module):
    """
    Distillation loss for a pruned student supervised by up to two teachers.

    The total loss is the sum of:
      * the base criterion on the labels (replaced by zeros when ``gt`` is falsy),
      * ``alpha`` * hard distillation of the student's distillation head from
        ``teacher_model`` (disabled when ``distillation_type == 'none'``),
      * ``alpha_full`` * logit distillation of both heads from ``teacher_model_full``
        (disabled when ``distillation_type_full == 'none'``), and
      * in the intermediate-feature variants, ``alpha_full_im`` * a weighted sum of
        attention-relation and MLP-feature losses against ``teacher_model_full``.

    ``forward`` is bound in ``__init__``: ``forward_wo_im`` when
    ``distillation_type_full_im == 'none'``, otherwise ``forward_with_im_minivit``.
    """
    def __init__(self, base_criterion: torch.nn.Module, gt: str,
                 teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float,
                 teacher_model_full,
                 distillation_type_full, alpha_full, tau_full,
                 distillation_attn_full_im, distillation_alpha_attn_full_im,
                 distillation_mlp_full_im, distillation_alpha_mlp_full_im,
                 distillation_type_full_im, alpha_full_im):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        # Used only for its truthiness: falsy disables the ground-truth term.
        self.gt = gt
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau
        self.teacher_model_full = teacher_model_full
        self.distillation_type_full = distillation_type_full
        self.alpha_full = alpha_full
        self.tau_full = tau_full
        self.distillation_attn_full_im = distillation_attn_full_im
        self.distillation_alpha_attn_full_im = distillation_alpha_attn_full_im
        self.distillation_mlp_full_im = distillation_mlp_full_im
        self.distillation_alpha_mlp_full_im = distillation_alpha_mlp_full_im
        self.distillation_type_full_im = distillation_type_full_im
        self.alpha_full_im = alpha_full_im
        if self.distillation_type_full_im == 'none':
            self.forward = self.forward_wo_im
        else:
            self.forward = self.forward_with_im_minivit
        self.choose_loss()

    def choose_loss(self):
        '''
        Select the logit loss ('soft' / 'hard' / 'mse') and the MLP-feature loss
        ('mse' / 'rel'). Other values leave the corresponding attribute unset.
        '''
        if self.distillation_type_full == 'soft':
            self.dist_full_loss = partial(soft_distillation, T=self.tau_full)
        elif self.distillation_type_full == 'hard':
            self.dist_full_loss = hard_distillation
        elif self.distillation_type_full == 'mse':
            self.dist_full_loss = mse
        if self.distillation_type_full_im == 'mse':
            self.dist_full_mlp_loss = cal_hidden_loss
        elif self.distillation_type_full_im == 'rel':
            self.dist_full_mlp_loss = cal_hidden_relation_loss

    def _gt_loss(self, outputs, labels):
        '''Base criterion on the labels, or zeros of the same shape when ``gt`` is falsy.'''
        base_loss = self.base_criterion(outputs, labels)
        if not self.gt:
            base_loss = torch.zeros_like(base_loss)
        return base_loss

    def _cnn_loss(self, inputs, outputs_kd, base_loss):
        '''alpha-weighted hard distillation from ``teacher_model`` (zeros when disabled).'''
        if self.distillation_type == 'none':
            return torch.zeros_like(base_loss)
        # NOTE(review): 'soft' is accepted by __init__, but this term always
        # uses hard (argmax) targets — confirm that is intended.
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)
        return self.alpha * F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

    def _full_logit_loss(self, outputs, outputs_kd, teacher_outputs, teacher_outputs_kd):
        '''alpha_full-weighted logit distillation of both heads from the full teacher.'''
        return self.alpha_full * self.dist_full_loss(
            outputs, teacher_outputs, outputs_kd, teacher_outputs_kd)

    def forward_wo_im(self, inputs, outputs, labels):
        '''
        Distill only the logits, without intermediate features.

        ``outputs`` is either a Tensor or an (outputs, outputs_kd) pair; the full
        teacher must return an (outputs, outputs_kd) pair.
        '''
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self._gt_loss(outputs, labels)
        loss_cnn = self._cnn_loss(inputs, outputs_kd, base_loss)
        if self.distillation_type_full == 'none':
            loss_full = torch.zeros_like(base_loss)
        else:
            with torch.no_grad():
                teacher_outputs, teacher_outputs_kd = self.teacher_model_full(inputs)
            loss_full = self._full_logit_loss(outputs, outputs_kd, teacher_outputs, teacher_outputs_kd)
        return base_loss + loss_cnn + loss_full

    def forward_with_im_minivit(self, inputs, outputs, labels):
        '''
        Distill the logits and intermediate features.

        Refer to the MiniViT paper <MiniViT: Compressing Vision Transformers with
        Weight Multiplexing>. ``outputs`` and the full teacher's return value are
        (outputs, outputs_kd, attns, mlps) tuples.
        '''
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd, attns, mlps]
            outputs, outputs_kd, outputs_attns, outputs_mlps = outputs
        base_loss = self._gt_loss(outputs, labels)
        loss_cnn = self._cnn_loss(inputs, outputs_kd, base_loss)

        need_full_logits = self.distillation_type_full != 'none'
        need_full_im = self.distillation_type_full_im != 'none' and (
            self.distillation_attn_full_im or self.distillation_mlp_full_im)
        # The full teacher is also needed when only feature distillation is
        # enabled; previously its features were then never computed.
        if need_full_logits or need_full_im:
            with torch.no_grad():
                teacher_outputs, teacher_outputs_kd, \
                    t_outputs_attns, t_outputs_mlps = self.teacher_model_full(inputs)

        if need_full_logits:
            loss_full = self._full_logit_loss(outputs, outputs_kd, teacher_outputs, teacher_outputs_kd)
        else:
            loss_full = torch.zeros_like(base_loss)

        loss_full_im = torch.zeros_like(base_loss)
        if self.distillation_type_full_im != 'none':
            if self.distillation_attn_full_im:
                loss_full_im += self.distillation_alpha_attn_full_im * cal_relation_loss(outputs_attns, t_outputs_attns)
            if self.distillation_mlp_full_im:
                loss_full_im += self.distillation_alpha_mlp_full_im * self.dist_full_mlp_loss(outputs_mlps, t_outputs_mlps)
            loss_full_im = self.alpha_full_im * loss_full_im
        return base_loss + loss_cnn + loss_full + loss_full_im

    def forward_with_im_tinybert(self, inputs, outputs, labels):
        '''
        Distill the logits and intermediate features with plain per-layer MSE.

        Refer to the TinyBERT paper <TinyBERT: Distilling BERT for Natural
        Language Understanding>. Not selected by ``__init__``; callable directly.
        '''
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd, attns, mlps]
            outputs, outputs_kd, outputs_attns, outputs_mlps = outputs
        base_loss = self._gt_loss(outputs, labels)
        loss_cnn = self._cnn_loss(inputs, outputs_kd, base_loss)

        need_full_logits = self.distillation_type_full != 'none'
        need_full_im = self.distillation_type_full_im == 'mse'
        # Run the full teacher whenever either term needs its outputs.
        if need_full_logits or need_full_im:
            with torch.no_grad():
                teacher_outputs, teacher_outputs_kd, \
                    teach_full_outputs_attns, teach_full_outputs_mlps = self.teacher_model_full(inputs)

        if need_full_logits:
            loss_full = self._full_logit_loss(outputs, outputs_kd, teacher_outputs, teacher_outputs_kd)
        else:
            loss_full = torch.zeros_like(base_loss)

        loss_full_im = torch.zeros_like(base_loss)
        if need_full_im:
            for outputs_attn, teach_full_outputs_attn in zip(outputs_attns, teach_full_outputs_attns):
                loss_full_im += F.mse_loss(outputs_attn, teach_full_outputs_attn)
            for outputs_mlp, teach_full_outputs_mlp in zip(outputs_mlps, teach_full_outputs_mlps):
                loss_full_im += F.mse_loss(outputs_mlp, teach_full_outputs_mlp)
        loss_full_im = self.alpha_full_im * loss_full_im
        return base_loss + loss_cnn + loss_full + loss_full_im