Merge branch '5-incorporate-alans-codebase-updates' into 'master'
added comments in training script to clarify how to train network from...

Closes #5

See merge request speech/xvectors!6
Kiran Karra committed Apr 23, 2021
2 parents a31e4e3 + 61bc86a commit fc9b498
Showing 5 changed files with 111 additions and 71 deletions.
15 changes: 6 additions & 9 deletions scripts/train_from_feats.py
@@ -150,9 +150,9 @@ def train_plda(args, model, device, train_loader):

# compute model output and update PLDA
x, y, z, output, w = model(data, embedding_only=True)
model.PLDA.update_plda(y, target)
model.plda.update_plda(y, target)

logger.info("PLDA training epoch, count range %.2f to %.2f" % (model.PLDA.counts.min(), model.PLDA.counts.max()))
logger.info("PLDA training epoch, count range %.2f to %.2f" % (model.plda.counts.min(),model.plda.counts.max()))


def main():
@@ -241,7 +241,9 @@ def main():
help='training CE boost margin (default: 0)')
parser.add_argument('--ResNet', action='store_true', default=False,
help='ResNet instead of TDNN (default False)')

parser.add_argument('--vb_flag', action='store_true',
help='use VB instead of leave-one-out (default False)')
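# NOTE: --vb_flag and leave-one-out are mutually exclusive; the model is
# constructed with loo_flag = (not args.vb_flag) below.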

args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
@@ -321,6 +323,7 @@ def main():
'fixed_N': args.fixed_N,
'r': args.enroll_R,
'enroll_type': args.enroll_type,
'loo_flag': not args.vb_flag,
'length_norm': args.length_norm,
'resnet_flag': args.ResNet
}
@@ -421,16 +424,10 @@ def main():

# Initial training: no scheduler or validation
if start_epoch <= args.init_epochs:
if init_scheduler is None and start_epoch == 1:
model.loo_flag = False # Cold start can't use leave-one-out
logger.info(" turning off leave-one-out for initialization")
logger.info("Starting initializer training from epoch %d for %d epochs", start_epoch, args.init_epochs)
for epoch in range(start_epoch, args.init_epochs + 1):
# train an epoch
train(args, model, device, train_loader, init_optimizer, epoch, args.train_cost, args.train_boost)
if not model.loo_flag:
model.loo_flag = True
logger.info(" turning leave-one-out back on")

# step learning rate
if init_scheduler is not None:
8 changes: 7 additions & 1 deletion scripts/train_nb.sh
@@ -30,7 +30,13 @@ DATA_COPY=0

# Options
MODEL_OPTS="--feature-dim=64 --embedding-dim=128 --layer-dim=768 --length_norm"
FRAME_OPTS="--random-frame-size --min-frames=200 --max-frames=200"
# NOTE:
# 1. To refine/optimize the DNN for a specific frame length, modify --min-frames and --max-frames.
# For example, to refine for a 2 s segment length, set --min-frames=200 --max-frames=200
# (see the commented preset below).
# 2. For general training, it may be more advantageous to set a wider range for --min-frames and
# --max-frames, so that the DNN is trained for a more general-purpose scenario.
# For general training, we set --min-frames=100 and --max-frames=250.
FRAME_OPTS="--random-frame-size --min-frames=100 --max-frames=250"

TRAIN_OPTS="--LLtype=Gauss --train_cost=CE --batch-size=512"

9 changes: 7 additions & 2 deletions scripts/train_wb.sh
@@ -28,8 +28,13 @@ DATA_COPY=0

# Options
MODEL_OPTS=" --feature-dim=80 --embedding-dim=128 --ResNet --length_norm "

FRAME_OPTS=" --random-frame-size --min-frames=200 --max-frames=200 "
# NOTE:
# 1. To refine/optimize the DNN for a specific frame length, modify --min-frames and --max-frames.
# For example, to refine for a 2 s segment length, set --min-frames=200 --max-frames=200
# (see the commented preset below).
# 2. For general training, it may be more advantageous to set a wider range for --min-frames and
# --max-frames, so that the DNN is trained for a more general-purpose scenario.
# For general training, we set --min-frames=100 and --max-frames=250.
FRAME_OPTS=" --random-frame-size --min-frames=100 --max-frames=250 "

# LLtype can be None, xvec, linear, Gauss_discr, or Gauss
LLTYPE="Gauss"
137 changes: 80 additions & 57 deletions xvectors/plda_lib.py
@@ -204,6 +204,22 @@ def update_counts(x1, labels1, sums1, counts1, N0, rand_flag=False):
return sums1, counts1


# Function to compute counts within a set
def compute_counts(x, labels):
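# NOTE: assumes the class labels already take values in {0, ..., M-1}, since
# sums and counts are indexed directly by label value; gauss_minibatch_ll
# below remaps labels into this range before calling.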
d = x.shape[1]
classes = list(set(labels.tolist()))
M = len(classes)
sums = x.new_zeros(M, d)
counts = x.new_zeros(M, )
for n in range(M):
m = classes[n]
ind = (labels == m)
sums[m, :] = x[ind, :].sum(dim=0)
counts[m] = ind.sum()

return sums, counts


class GaussLinear(nn.Module):
def __init__(self, embedding_dim, num_classes, N0=9, fixed_N=True, discr_mean=False):
super(GaussLinear, self).__init__()
@@ -249,7 +265,7 @@ def mean_loss(self):


class GaussQuadratic(nn.Module):
def __init__(self, embedding_dim, num_classes, N0=9, fixed_N=True, r=0.9, enroll_type='Bayes', N_dict={}, OOS=True):
def __init__(self, embedding_dim, num_classes, N0=9, fixed_N=True, r=0.9, enroll_type='Bayes', loo_flag=True, OOS=False):
super(GaussQuadratic, self).__init__()

self.num_classes = num_classes
@@ -262,7 +278,7 @@ def __init__(self, embedding_dim, num_classes, N0=9, fixed_N=True, r=0.9, enroll
self.register_buffer('sums', torch.zeros(num_classes,embedding_dim))
#self.register_buffer('counts', torch.zeros(num_classes,))
self.register_buffer('counts', torch.ones(num_classes,))
self.N_dict = N_dict
self.loo_flag = loo_flag
self.OOS = OOS

# Initialize stats
@@ -273,12 +289,13 @@ def forward(self, x, w=None, Ulda=None, d_wc=None, d_ac=None):

# Update models and compute Gaussian log-likelihoods
with torch.no_grad():
self.means.data, self.cov.data = gmm_adapt(self.counts, torch.mm(self.sums,Ulda), d_wc, d_ac, self.r, self.enroll_type, self.N_dict)
self.means.data, self.cov.data = gmm_adapt(self.counts, torch.mm(self.sums,Ulda), d_wc, d_ac, self.r, self.enroll_type)

N = x.shape[0]
M = self.num_classes
vb_flag = not self.loo_flag
if w is None:
LL = gmm_score(x, self.means, self.cov+d_wc[None,:])
LL = gmm_score_bayes(x, self.means, self.cov, d_wc[None,:], vb_flag)

# Subtract open set LL
if self.OOS:
@@ -287,7 +304,7 @@ def forward(self, x, w=None, Ulda=None, d_wc=None, d_ac=None):
# confidence weighting
LL = x.new_zeros((N,M))
for n in range(N):
LL[n,:] = gmm_score(x[n,None], self.means, self.cov+w[None,n])
LL[n,:] = gmm_score_bayes(x[n,None], self.means, self.cov, w[None,n], vb_flag)
if self.OOS:
LL[n,:] -= gmm_score(x, 0.0, d_ac[None,:]+w[None,n])

@@ -317,9 +334,9 @@ def compute_loss(x, y, output, w, labels, loss_type='CE', model=None, boost=0):
loss = nloss
acc = accuracy(output, labels)
elif loss_type == 'GaussLoss':
loss, nloss, acc = gauss_loss(y, w, labels, loo_flag=model.loo_flag, cov_ac=model.plda.d_ac, enroll_type=model.enroll_type, r=model.r, N_dict=model.N_dict)
loss, nloss, acc = gauss_loss(y, w, labels, loo_flag=model.loo_flag, cov_ac=model.plda.d_ac, enroll_type=model.enroll_type, r=model.r)
elif loss_type == 'BinLoss':
loss, nloss, acc = bin_loss(y, w, labels, loo_flag=model.loo_flag, cov_ac=model.plda.d_ac, enroll_type=model.enroll_type, r=model.r, N_dict=model.N_dict)
loss, nloss, acc = bin_loss(y, w, labels, loo_flag=model.loo_flag, cov_ac=model.plda.d_ac, enroll_type=model.enroll_type, r=model.r)
else:
raise ValueError("Invalid loss type %s." % loss_type)

@@ -397,82 +414,96 @@ def bce_loss(LLR, labels, Pt=None):
# Gaussian diarization loss in minibatch
# Compute Gaussian cost across minibatch of samples vs. average
# Note: w is ignored in this version
def gauss_minibatch_ll(x1, w, labels, loo_flag=True, cov_wc1=None, cov_ac1=None, enroll_type='Bayes', r=0.9, N_dict=None, binary=False):

x = x1.cpu()
def gauss_minibatch_ll(x1, w, labels, loo_flag=True, cov_wc1=None, cov_ac1=None, enroll_type='Bayes', r=0.9, binary=False):
cpu_flag = True
if cpu_flag:
x = x1.cpu()
l2 = labels.clone().cpu()
else:
x = x1
l2 = labels.clone()
N = x.shape[0]
d = x.shape[1]
classes = list(set(labels.tolist()))
M = len(classes)
l2 = labels.clone().cpu()
sums = x.new_zeros((M,d))
counts = x.new_zeros((M,))
if cov_wc1 is None:
cov_wc = x.new_ones((d,))
else:
cov_wc = cov_wc1.cpu()
cov_wc1 = x.new_ones((d,))
if cov_ac1 is None or len(cov_ac1.shape) > 1:
cov_ac = x.new_ones((d,))
cov_ac1 = x.new_ones((d,))
if cpu_flag:
cov_wc = cov_wc1.cpu()
cov_ac = cov_ac1.cpu()
else:
cov_ac = cov_ac1.cpu()
cov_wc = cov_wc1
cov_ac = cov_ac1
cov_test = 1.0
if N_dict is None:
N_dict = {}
vb_flag = not loo_flag

# Compute stats for classes
for m in range(M):
l2[labels==classes[m]] = m
sums, counts = update_counts(x, l2, sums, counts, N0=1000, rand_flag=0)
l2[labels == classes[m]] = m

# Compute models and log-likelihoods
means, cov = gmm_adapt(counts, sums, cov_wc, cov_ac, r, enroll_type, N_dict)
LL = gmm_score(x, means, cov+cov_test)

# Leave one out corrections
# Compute GMM log-likelihoods
if loo_flag:

# Fast leave-one-out: same LL but fast gradients ignore centroids
sums, counts = compute_counts(x, l2)
means, cov = gmm_adapt(counts, sums, cov_wc, cov_ac, r, enroll_type)
LL = gmm_score_bayes(x, means, cov, cov_test, vb_flag)
for n in range(N):
m = classes.index(labels[n])
if counts[m] > 1:
mu_model, cov_model = gmm_adapt(counts[m:m+1]-1, sums[m:m+1,:]-x[n,:], cov_wc, cov_ac, r, enroll_type, N_dict)
LL[n,m] = gmm_score(x[n:n+1,:], mu_model, cov_model+cov_test)
mu_model, cov_model = gmm_adapt(counts[m:m+1]-1, sums[m:m+1,:]-x[n,:], cov_wc, cov_ac, r, enroll_type)
LL[n,m] = gmm_score_bayes(x[n:n+1,:], mu_model, cov_model, cov_test, vb_flag)

else:
# Compute stats for classes
sums, counts = compute_counts(x, l2)

# Compute models and log-likelihoods
means, cov = gmm_adapt(counts, sums, cov_wc, cov_ac, r, enroll_type)

# Score models
LL = gmm_score_bayes(x, means, cov, cov_test, vb_flag)

if binary:
# Subtract open set LL
LL -= gmm_score(x, 0.0, cov_wc[None,:]+cov_ac[None,:])
prior = 1.0/M
LL -= gmm_score(x, 0.0, cov_wc[None, :] + cov_ac[None, :])
prior = 1.0 / M

else:
# Compute and apply prior
prior = counts/counts.sum()
prior = counts / counts.sum()
logprior = torch.log(prior)
LL += logprior

return LL, prior, l2


# Gaussian diarization loss in minibatch
def gauss_loss(x, w, labels, loo_flag=True, cov_wc=None, cov_ac=None, enroll_type='Bayes', r=0.9, N_dict=None):
def gauss_loss(x, w, labels, loo_flag=True, cov_wc=None, cov_ac=None, enroll_type='Bayes', r=0.9):

# Return normalized cross entropy cost
LL, prior, l2 = gauss_minibatch_ll(x, w, labels, loo_flag, cov_wc, cov_ac, enroll_type, r, N_dict)
LL, prior, l2 = gauss_minibatch_ll(x, w, labels, loo_flag, cov_wc, cov_ac, enroll_type, r)
loss, nloss = nce_loss(LL, l2, prior)
acc = accuracy(LL,l2)
return loss, nloss, acc


# Binary T/NT diarization loss in minibatch
def bin_loss(x, w, labels, loo_flag=True, cov_wc=None, cov_ac=None, enroll_type='Bayes', r=0.9, N_dict=None):
def bin_loss(x, w, labels, loo_flag=True, cov_wc=None, cov_ac=None, enroll_type='Bayes', r=0.9):

# Return normalized cross entropy cost
LL, prior, l2 = gauss_minibatch_ll(x, w, labels, loo_flag, cov_wc, cov_ac, enroll_type, r, N_dict, binary=True)
LL, prior, l2 = gauss_minibatch_ll(x, w, labels, loo_flag, cov_wc, cov_ac, enroll_type, r, binary=True)
loss, nloss = bce_loss(LL, l2)
acc = accuracy(LL,l2)
return loss, nloss, acc


# Function for Bayesian adaptation of Gaussian model
# Enroll type can be ML, MAP, or Bayes
def gmm_adapt(cnt, xsum, cov_wc, cov_ac, r=0, enroll_type='ML', N_dict=None):
def gmm_adapt(cnt, xsum, cov_wc, cov_ac, r=0, enroll_type='ML'):

# Compute ML model
cnt = torch.max(0*cnt+(1e-10),cnt)
@@ -490,7 +521,7 @@ def gmm_adapt(cnt, xsum, cov_wc, cov_ac, r=0, enroll_type='ML', N_dict=None):
elif r == 1:
Nsc = 0.0*cnt+1.0
else:
Nsc = compute_Nsc(cnt, r, N_dict)
Nsc = compute_Nsc(cnt, r)
cov_mean = cov_wc*Nsc[:,None]

# MAP mean plus model uncertainty
@@ -504,26 +535,10 @@ def gmm_adapt(cnt, xsum, cov_wc, cov_ac, r=0, enroll_type='ML', N_dict=None):
return mu_model, cov_model


def compute_Nsc(cnts, r, N_dict=None):
def compute_Nsc(cnts, r):

# Correlation model for enrollment cuts (0=none,1=single-cut)
if N_dict is None:
N_dict = {}
Nsc = cnts.clone()
icnt = (0.5+cnts.cpu().numpy()).astype(np.int)
for cnt in np.unique(icnt):
if cnt not in N_dict.keys():
if cnt < 1:
Neff = cnt
else:
Neff = (cnt*(1-r)+2*r) / (1.0+r)

N_dict[cnt] = 1.0 / Neff
print("cnt not in dict", cnt, N_dict[cnt])

# Update N_eff
mask = torch.from_numpy(np.array(icnt==cnt, dtype=np.uint8))
Nsc[mask] = N_dict[cnt]
Nsc = (1.0+r) / (cnts*(1-r)+2*r)
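# equivalently Nsc = 1/Neff, with effective count Neff = (cnts*(1-r) + 2*r)/(1+r),
# computed in closed form over the whole count vector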

return Nsc

Expand All @@ -541,3 +556,11 @@ def gmm_score(X, means, covars):

return LLs

def gmm_score_bayes(X, means, cov_model, cov_wc, vb_flag=False):
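# Score X under Gaussian models whose means carry uncertainty cov_model: the
# total covariance is cov_model + cov_wc. If vb_flag is set, a variational-Bayes
# penalty of -0.5*sum(cov_model/cov_wc) per model is added to the log-likelihoods.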

LLs = gmm_score(X, means, cov_model+cov_wc)
if vb_flag:
vb_penalty = -0.5*(cov_model/cov_wc).sum(axis=1)
LLs += vb_penalty

return LLs
13 changes: 11 additions & 2 deletions xvectors/xvector_model.py
@@ -164,7 +164,6 @@ def __init__(self, input_dim, layer_dim, embedding_dim, num_classes, LL='linear'
self.LL = LL
self.enroll_type = enroll_type
self.r = r
self.N_dict = {}
self.loo_flag = loo_flag
self.prepooling_frozen = False
self.embedding_frozen = False
@@ -204,7 +203,7 @@ def __init__(self, input_dim, layer_dim, embedding_dim, num_classes, LL='linear'
if enroll_type == 'ML':
self.output = GaussLinear(embedding_dim, num_classes, N0, fixed_N)
else:
self.output = GaussQuadratic(embedding_dim, num_classes, N0, fixed_N, r, enroll_type, self.N_dict)
self.output = GaussQuadratic(embedding_dim, num_classes, N0, fixed_N, r, enroll_type)

elif self.LL == 'Gauss_discr':
# Gaussian discriminative means
@@ -258,6 +257,16 @@ def update_params(self, x, y, z, labels):
self.output.update_params(y, labels)
return

def freeze_prepooling(self):

model = self.embed
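# unwrap a DataParallel-style wrapper, which stores the underlying network in .module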
if hasattr(model, 'module'):
model = model.module
freeze_prepooling(model)

return


### Resnet embedding functions
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
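For reference, here is a minimal sketch of a training invocation that combines the options these scripts define with the new flag. It is illustrative only: the option strings are taken from MODEL_OPTS, FRAME_OPTS, and TRAIN_OPTS in train_nb.sh and from the argparse additions above, and the required data-path arguments are omitted.

python scripts/train_from_feats.py \
    --feature-dim=64 --embedding-dim=128 --layer-dim=768 --length_norm \
    --random-frame-size --min-frames=100 --max-frames=250 \
    --LLtype=Gauss --train_cost=CE --batch-size=512 \
    --vb_flag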
