-
Notifications
You must be signed in to change notification settings - Fork 16
/
ge2e.py
106 lines (91 loc) · 4.01 KB
/
ge2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
import torch.nn.functional as F
class GE2ELoss(nn.Module):
    '''
    Generalized End-to-End (GE2E) speaker-verification loss, as defined in
    https://arxiv.org/abs/1710.10467 [1].

    The module owns two learnable scalars, w and b, which scale and shift the
    cosine-similarity matrix before the per-utterance loss is computed.
    '''

    def __init__(self, init_w=10.0, init_b=-5.0, loss_method='softmax'):
        '''
        Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]

        Accepts an input of size (N, M, D)

            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector (e.g. d-vector)

        Args:
            - init_w (float): defines the initial value of w in Equation (5) of [1]
            - init_b (float): defines the initial value of b in Equation (5) of [1]
            - loss_method (str): 'softmax' or 'contrast'; selects which
              per-embedding loss is bound to self.embed_loss

        Raises:
            AssertionError: if loss_method is not one of the two supported values
        '''
        super(GE2ELoss, self).__init__()
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))
        self.loss_method = loss_method

        assert self.loss_method in ['softmax', 'contrast']

        if self.loss_method == 'softmax':
            self.embed_loss = self.embed_loss_softmax
        if self.loss_method == 'contrast':
            self.embed_loss = self.embed_loss_contrast

    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        '''
        Calculates the new centroids excluding the reference utterance

        Returns a copy of `centroids` in which row `spkr` is replaced by the
        mean of that speaker's utterances with utterance `utt` left out.
        NOTE: with a single utterance per speaker (M == 1) the left-out set is
        empty, so its mean is not a meaningful centroid.
        '''
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1:]))
        excl = torch.mean(excl, 0)
        return torch.stack([
            excl if i == spkr else centroid
            for i, centroid in enumerate(centroids)
        ])

    def calc_cosine_sim(self, dvecs, centroids):
        '''
        Make the cosine similarity matrix with dims (N,M,N)

        Entry [j, i, k] is the cosine similarity between utterance i of
        speaker j and the centroid of speaker k, where the centroid of the
        utterance's own speaker excludes that utterance. Similarities are
        floored at 1e-6, so negative values are discarded.
        '''
        cos_sim_matrix = []
        for spkr_idx, speaker in enumerate(dvecs):
            cs_row = []
            for utt_idx, utterance in enumerate(speaker):
                new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
                # vector based cosine similarity for speed:
                # (1, D) @ (D, N) -> (1, N), then divide by the norm products
                dots = torch.mm(utterance.unsqueeze(0), new_centroids.transpose(0, 1))
                norms = torch.norm(utterance) * torch.norm(new_centroids, dim=1)
                cs_row.append(torch.clamp(dots / norms, 1e-6))
            cos_sim_matrix.append(torch.cat(cs_row, dim=0))
        return torch.stack(cos_sim_matrix)

    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        '''
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax

        Returns an (N, M) tensor whose [j, i] entry is the negative
        log-softmax (over the centroid axis) of utterance i of speaker j
        evaluated at its own speaker j.
        '''
        N, M, _ = dvecs.shape
        # One batched log-softmax over the centroid axis replaces the
        # per-utterance Python loop.
        log_probs = F.log_softmax(cos_sim_matrix, dim=2)
        spkr = torch.arange(N, device=cos_sim_matrix.device)
        # Paired advanced indices pick [j, :, j] for every j -> shape (N, M)
        return -log_probs[spkr, :, spkr]

    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        '''
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid

        Returns an (N, M) tensor. Needs at least two speakers: with N == 1
        there is no "other" centroid to take the maximum over.
        '''
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = []
            for i in range(M):
                centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                # Drop the utterance's own speaker; the max over the rest is
                # the closest competing centroid.
                excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1:]))
                L_row.append(1. - centroids_sigmoids[j] + torch.max(excl_centroids_sigmoids))
            L.append(torch.stack(L_row))
        return torch.stack(L)

    def forward(self, dvecs):
        '''
        Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)

        Returns the sum of the per-utterance losses as a scalar tensor.
        '''
        # Calculate centroids
        centroids = torch.mean(dvecs, 1)

        # Calculate the cosine similarity matrix
        cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)

        # torch.clamp is out-of-place, so its result must be used: keep the
        # scale strictly positive so a larger cosine similarity always yields
        # a larger scaled similarity.
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = cos_sim_matrix * w + self.b
        L = self.embed_loss(dvecs, cos_sim_matrix)
        return L.sum()