-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
334 lines (293 loc) · 14.3 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
from torch.nn.modules.distance import PairwiseDistance
from torch.autograd import Function
# Support: ['Softmax', 'ArcFace', 'CosFace', 'SphereFace', 'Am_softmax', 'Triplet']
class Softmax(nn.Module):
    r"""Plain linear classification head (no margin).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: list of GPU ids over which the weight matrix is split
            row-wise (model parallel). If ``None``, a single ``F.linear``
            call is made and nothing is moved between devices.
    """

    def __init__(self, in_features, out_features, device_id):
        super(Softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        # ``torch.nn.init`` has no ``zero_`` function; fill the bias with
        # zeros through ``constant_`` instead.
        nn.init.constant_(self.bias, 0.0)

    def forward(self, x):
        """Return the logits ``x @ weight.T + bias``.

        Args:
            x: input features; assumed to be ``(batch, in_features)``.

        Returns:
            Logits with ``out_features`` columns. In the model-parallel case
            the result lives on ``device_id[0]``.
        """
        if self.device_id is None:
            return F.linear(x, self.weight, self.bias)

        # Model parallel: each device computes the logits of its own slice of
        # classes, and the slices are concatenated on the first device.
        primary = self.device_id[0]
        n_devices = len(self.device_id)
        weight_parts = torch.chunk(self.weight, n_devices, dim=0)
        bias_parts = torch.chunk(self.bias, n_devices, dim=0)

        pieces = []
        for dev, w_part, b_part in zip(self.device_id, weight_parts, bias_parts):
            logits = F.linear(x.cuda(dev), w_part.cuda(dev), b_part.cuda(dev))
            pieces.append(logits.cuda(primary))
        return torch.cat(pieces, dim=1)

    def _initialize_weights(self):
        """Re-initialize conv / batch-norm / linear sub-modules.

        Conv2d and Linear weights get Xavier-uniform values and zero biases;
        BatchNorm1d/2d get weight 1 and bias 0. Only modules returned by
        ``self.modules()`` are visited.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class ArcFace(nn.Module):
    r"""Implementation of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf).

    The target-class logit ``cos(theta)`` is replaced by ``cos(theta + m)``
    and all logits are multiplied by ``s``.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: list of GPU ids over which the class weights are split
            (model parallel). If ``None``, no tensors are moved between
            devices.
        s: scale applied to the final logits
        m: additive angular margin, in radians
        easy_margin: if True, the margin is applied only where
            ``cos(theta) > 0`` and the plain cosine is used elsewhere
    """

    def __init__(self, in_features, out_features, device_id, s = 64.0, m = 0.50, easy_margin = False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(pi - m): below this cosine, theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        # Penalty subtracted from the cosine in that region instead.
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return the scaled margin logits.

        Args:
            input: feature batch; assumed to be ``(batch, in_features)``.
            label: class index per sample; reshaped to ``(batch, 1)`` and
                cast to int64 before use.

        Returns:
            Tensor with one column per class, multiplied by ``s``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            # Each device handles one slice of the classes; the slices are
            # gathered on the first device.
            primary = self.device_id[0]
            weight_parts = torch.chunk(self.weight, len(self.device_id), dim=0)
            pieces = []
            for dev, w_part in zip(self.device_id, weight_parts):
                part = F.linear(F.normalize(input.cuda(dev)), F.normalize(w_part.cuda(dev)))
                pieces.append(part.cuda(primary))
            cosine = torch.cat(pieces, dim=1)

        # Rounding can push |cosine| marginally above 1, which would make the
        # radicand negative and the square root NaN; clamp it to [0, 1].
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # Build the mask on the same device/dtype as the logits so the layer
        # also works when device_id is None but the tensors live on a GPU.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)

        # Target column takes phi, every other column keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
class CosFace(nn.Module):
    r"""Implementation of CosFace (https://arxiv.org/pdf/1801.09414.pdf).

    The target-class logit ``cos(theta)`` is replaced by ``cos(theta) - m``
    and all logits are multiplied by ``s``.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: list of GPU ids over which the class weights are split
            (model parallel). If ``None``, no tensors are moved between
            devices.
        s: scale applied to the final logits
        m: additive cosine margin
    """

    def __init__(self, in_features, out_features, device_id, s = 64.0, m = 0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Return the scaled margin logits.

        Args:
            input: feature batch; assumed to be ``(batch, in_features)``.
            label: class index per sample; reshaped to ``(batch, 1)`` and
                cast to int64 before use.

        Returns:
            Tensor with one column per class, multiplied by ``s``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            # Each device handles one slice of the classes; the slices are
            # gathered on the first device.
            primary = self.device_id[0]
            weight_parts = torch.chunk(self.weight, len(self.device_id), dim=0)
            pieces = []
            for dev, w_part in zip(self.device_id, weight_parts):
                part = F.linear(F.normalize(input.cuda(dev)), F.normalize(w_part.cuda(dev)))
                pieces.append(part.cuda(primary))
            cosine = torch.cat(pieces, dim=1)
        phi = cosine - self.m

        # --------------------------- convert label to one-hot ---------------------------
        # Build the mask on the same device/dtype as the logits so the layer
        # also works when device_id is None but the tensors live on a GPU.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)

        # Target column takes phi, every other column keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

    def __repr__(self):
        return '{}(in_features = {}, out_features = {}, s = {}, m = {})'.format(
            self.__class__.__name__,
            str(self.in_features),
            str(self.out_features),
            str(self.s),
            str(self.m),
        )
class SphereFace(nn.Module):
    r"""Implementation of SphereFace (https://arxiv.org/pdf/1704.08063.pdf).

    The target-class logit ``cos(theta)`` is blended with the margin term
    ``phi(theta) = (-1)^k * cos(m * theta) - 2k`` and every logit is scaled
    by the L2 norm of the corresponding input feature.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: list of GPU ids over which the class weights are split
            (model parallel). If ``None``, no tensors are moved between
            devices.
        m: integer angular margin multiplier; used as an index into the
            multiple-angle table, so it must be in ``0..5``
    """

    def __init__(self, in_features, out_features, device_id, m = 4):
        super(SphereFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        # Annealing schedule for the blending factor lambda (see forward()).
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        # cos(k * theta) written as a polynomial in x = cos(theta), k = 0..5
        self.mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]

    def forward(self, input, label):
        """Return the margin logits scaled by the feature norm.

        Every call increments ``self.iter`` and updates ``self.lamb``; there
        is no check for training vs. evaluation mode.

        Args:
            input: feature batch; assumed to be ``(batch, in_features)``.
            label: class index per sample; reshaped to ``(batch, 1)`` and
                cast to int64 before use.

        Returns:
            Tensor with one column per class.
        """
        # lambda = max(lambda_min, base * (1 + gamma * iteration) ^ (-power))
        self.iter += 1
        self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            features = input
            cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
        else:
            # Every weight chunk must contribute its class columns; using only
            # the first chunk would truncate the output when more than one
            # device is listed. Slices are gathered on the first device.
            primary = self.device_id[0]
            features = input.cuda(primary)
            weight_parts = torch.chunk(self.weight, len(self.device_id), dim=0)
            pieces = []
            for dev, w_part in zip(self.device_id, weight_parts):
                part = F.linear(F.normalize(input.cuda(dev)), F.normalize(w_part.cuda(dev)))
                pieces.append(part.cuda(primary))
            cos_theta = torch.cat(pieces, dim=1)
        cos_theta = cos_theta.clamp(-1, 1)

        cos_m_theta = self.mlambda[self.m](cos_theta)
        # theta and k are computed on detached data, so no gradient flows
        # through the piece index k.
        theta = cos_theta.data.acos()
        k = (self.m * theta / math.pi).floor()
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
        # Norm taken on the copy that lives with the logits so the final
        # scaling never mixes devices.
        NormOfFeature = torch.norm(features, 2, 1)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros_like(cos_theta)
        # scatter_ needs an int64 index on the same device as the mask.
        one_hot.scatter_(1, label.view(-1, 1).long().to(cos_theta.device), 1)

        # --------------------------- Calculate output ---------------------------
        # Target column: cos + (phi - cos) / (1 + lambda); others: cos.
        output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
        output *= NormOfFeature.view(-1, 1)
        return output

    def __repr__(self):
        return '{}(in_features = {}, out_features = {}, m = {})'.format(
            self.__class__.__name__,
            str(self.in_features),
            str(self.out_features),
            str(self.m),
        )
def l2_norm(input, axis = 1):
    """Divide ``input`` by its L2 norm taken along ``axis``.

    The norm is computed with the reduced dimension kept, so the division
    broadcasts back over ``input``. No epsilon is added to the norm.
    """
    lengths = input.norm(p=2, dim=axis, keepdim=True)
    return input / lengths
class Am_softmax(nn.Module):
    r"""Implementation of Am_softmax (https://arxiv.org/pdf/1801.05599.pdf).

    The target-class logit ``cos(theta)`` is replaced by ``cos(theta) - m``
    and all logits are multiplied by ``s``. Only the kernel columns are
    normalized here; the embeddings are used as given.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        device_id: list of GPU ids over which the kernel columns are split
            (model parallel). If ``None``, no tensors are moved between
            devices.
        m: additive cosine margin
        s: scale of outputs
    """

    def __init__(self, in_features, out_features, device_id, m = 0.35, s = 30.0):
        super(Am_softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s
        self.device_id = device_id

        # Kernel is stored as (in_features, out_features): one column per class.
        self.kernel = Parameter(torch.Tensor(in_features, out_features))
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)  # initialize kernel

    def forward(self, embbedings, label):
        """Return the scaled margin logits.

        Args:
            embbedings: feature batch; assumed to be ``(batch, in_features)``
                and already L2-normalized by the caller -- TODO confirm.
            label: class index per sample; reshaped to ``(batch, 1)`` and
                cast to int64 before use.

        Returns:
            Tensor with one column per class, multiplied by ``s``.
        """
        # F.normalize gives the same unit-length columns as dividing by the
        # column norm, but does not produce NaN for an all-zero column.
        if self.device_id is None:
            kernel_norm = F.normalize(self.kernel, dim=0)
            cos_theta = torch.mm(embbedings, kernel_norm)
        else:
            # Each device handles one slice of the classes; the slices are
            # gathered on the first device.
            primary = self.device_id[0]
            kernel_parts = torch.chunk(self.kernel, len(self.device_id), dim=1)
            pieces = []
            for dev, k_part in zip(self.device_id, kernel_parts):
                kernel_norm = F.normalize(k_part, dim=0).cuda(dev)
                pieces.append(torch.mm(embbedings.cuda(dev), kernel_norm).cuda(primary))
            cos_theta = torch.cat(pieces, dim=1)

        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability

        # One-hot mask of the target class, size=(B, Classnum). Subtracting
        # m * mask applies the margin to the target column only, without
        # indexing through a uint8 mask.
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cos_theta.device), 1)

        output = cos_theta - one_hot * self.m
        output = output * self.s  # scale up in order to make softmax work, first introduced in normface
        return output
class TripletLoss(nn.Module):
    """Triplet margin loss: ``mean(max(margin + d(a, p) - d(a, n), 0))``.

    ``d`` is the pairwise L2 distance. The class derives from ``nn.Module``
    rather than ``torch.autograd.Function``: a Function with ``__init__`` and
    a non-static ``forward`` is the legacy style and its instances cannot be
    called on current PyTorch. Calling ``.forward(...)`` directly still works.

    Args:
        margin: minimum gap required between the anchor-negative and the
            anchor-positive distance.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.pdist = PairwiseDistance(2)

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over the batch.

        Args:
            anchor, positive, negative: embedding batches passed pairwise to
                ``PairwiseDistance``; assumed to share one shape.

        Returns:
            Scalar tensor with the averaged loss.
        """
        pos_dist = self.pdist(anchor, positive)
        neg_dist = self.pdist(anchor, negative)

        # Triplets whose negative is already margin-farther contribute 0.
        hinge_dist = torch.clamp(self.margin + pos_dist - neg_dist, min=0.0)
        return torch.mean(hinge_dist)