-
Notifications
You must be signed in to change notification settings - Fork 29
/
ivector_PLDA_CSI.py
136 lines (100 loc) · 5.49 KB
/
ivector_PLDA_CSI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
## Copyright (C) 2019, Guangke Chen <gkchen.shanghaitech@gmail.com>.
## This program is licenced under the BSD 2-Clause licence
## contained in the LICENCE file in this directory.
from ivector_PLDA_kaldiHelper import ivector_PLDA_kaldiHelper
import numpy as np
import os
import shutil
import time
import subprocess
import shlex
import copy
# Module-level sample width in bits. The iv_CSI methods in this file take
# their own `bits_per_sample` parameter (also defaulting to 16), which shadows
# this global inside them.
# NOTE(review): presumably kept for external importers -- confirm before removing.
bits_per_sample = 16
class iv_CSI:
    """Speaker identification over a fixed set of enrolled ivector-PLDA models.

    Each enrolled speaker is described by a 5-item record
    ``(spk_id, utt_id, identity_location, z_norm_mean, z_norm_std)``.
    Scoring is delegated to ``ivector_PLDA_kaldiHelper``; this class manages
    the working directories, converts audio to int16, and z-normalises the
    raw scores per speaker.
    """

    def __init__(self, group_id, model_list, pre_model_dir="pre-models"):
        """Set up working directories and write the enrolment ivector scp file.

        Args:
            group_id: Directory (created if missing) that holds every
                intermediate file of this speaker group.
            model_list: Sequence of records indexed as ``model[0]`` .. ``model[4]``:
                speaker id, utterance id, identity location, z-norm mean and
                z-norm std.
            pre_model_dir: Directory passed to the Kaldi helper as
                ``pre_model_dir``.
        """
        self.pre_model_dir = os.path.abspath(pre_model_dir)
        self.group_id = os.path.abspath(group_id)
        if not os.path.exists(self.group_id):
            os.makedirs(self.group_id)

        self.audio_dir = os.path.abspath(self.group_id + "/audio")
        self.mfcc_dir = os.path.abspath(self.group_id + "/mfcc")
        self.log_dir = os.path.abspath(self.group_id + "/log")
        self.ivector_dir = os.path.abspath(self.group_id + "/ivector")

        self.n_speakers = len(model_list)
        self.spk_ids = []
        self.utt_ids = []
        self.identity_locations = []
        self.z_norm_means = np.zeros(self.n_speakers, dtype=np.float64)
        self.z_norm_stds = np.zeros(self.n_speakers, dtype=np.float64)

        for i, model in enumerate(model_list):
            self.spk_ids.append(model[0])
            self.utt_ids.append(model[1])
            self.identity_locations.append(model[2])
            self.z_norm_means[i] = model[3]
            self.z_norm_stds[i] = model[4]

        # Keep every per-speaker sequence in sorted speaker-id order; otherwise
        # Kaldi may order the entries itself, which may lead to wrong results.
        (self.spk_ids, self.utt_ids, self.identity_locations,
         self.z_norm_means, self.z_norm_stds) = self.order(
            self.spk_ids, self.utt_ids, self.identity_locations,
            self.z_norm_means, self.z_norm_stds)

        # One "<utt_id> <identity_location>" line per enrolled speaker.
        self.train_ivector_scp = self.group_id + "/ivector.scp"
        np.savetxt(
            self.train_ivector_scp,
            np.concatenate(
                (np.array(self.utt_ids)[:, np.newaxis],
                 np.array(self.identity_locations)[:, np.newaxis]),
                axis=1),
            fmt="%s")

        self.kaldi_helper = ivector_PLDA_kaldiHelper(
            pre_model_dir=self.pre_model_dir, audio_dir=self.audio_dir,
            mfcc_dir=self.mfcc_dir, log_dir=self.log_dir,
            ivector_dir=self.ivector_dir)

    @staticmethod
    def _permute(values, perm):
        """Return `values` reordered by the index list `perm` (ndarray stays ndarray)."""
        if isinstance(values, np.ndarray):
            return values[np.asarray(perm, dtype=np.intp)]
        return [values[i] for i in perm]

    @staticmethod
    def _to_int16(audio, bits_per_sample):
        """Scale `audio` by 2**(bits_per_sample - 1) and cast to int16 with clipping."""
        scaled = audio * (2 ** (bits_per_sample - 1))
        limits = np.iinfo(np.int16)
        # Clip first: a full-scale +1.0 sample scales to 32768, which would
        # otherwise overflow int16 and wrap around to a negative value.
        return np.clip(scaled, limits.min, limits.max).astype(np.int16)

    def order(self, spk_ids, utt_ids, identity_locations, z_norm_means, z_norm_stds):
        """Sort all per-speaker sequences by ascending speaker id.

        A single stable permutation is computed from `spk_ids` and applied to
        every sequence, so entries that share a speaker id each keep their own
        utterance id, location and statistics. The inputs are not modified.

        Returns:
            Tuple ``(spk_ids, utt_ids, identity_locations, z_norm_means,
            z_norm_stds)`` in sorted order; ndarray inputs come back as new
            ndarrays, other sequences as new lists.
        """
        perm = sorted(range(len(spk_ids)), key=lambda i: spk_ids[i])
        return (
            self._permute(spk_ids, perm),
            self._permute(utt_ids, perm),
            self._permute(identity_locations, perm),
            self._permute(z_norm_means, perm),
            self._permute(z_norm_stds, perm),
        )

    def score(self, audio_list, fs=16000, bits_per_sample=16, n_jobs=10, debug=False):
        """Return z-normalised scores of the given audio against every enrolled speaker.

        Args:
            audio_list: Either a list of per-utterance arrays, or one ndarray.
                A 1-D array, or a 2-D array with a singleton dimension, is
                treated as one utterance; any other 2-D array is split into
                one utterance per column.
            fs: Accepted for interface compatibility; not used by this method.
            bits_per_sample: Sample width used to scale non-int16 audio to
                int16 (multiplied by ``2 ** (bits_per_sample - 1)``).
            n_jobs: Forwarded to the Kaldi helper.
            debug: Forwarded to the Kaldi helper.

        Returns:
            Score array, per the original author's note shaped
            ``(n_audios, n_spks)`` or ``(n_spks,)``.
        """
        # Start from empty working directories so files from a previous call
        # cannot be picked up again.
        for work_dir in (self.audio_dir, self.mfcc_dir, self.log_dir, self.ivector_dir):
            if os.path.exists(work_dir):
                shutil.rmtree(work_dir)
            os.makedirs(work_dir)

        if isinstance(audio_list, np.ndarray):
            shape = audio_list.shape
            single_utterance = len(shape) == 1 or (
                len(shape) == 2 and (shape[0] == 1 or shape[1] == 1))
            if single_utterance:
                audio_list = [audio_list]
            else:
                audio_list = [audio_list[:, i] for i in range(shape[1])]
        else:
            # Work on a copy so the caller's list is not modified below.
            audio_list = copy.deepcopy(audio_list)

        for i, audio in enumerate(audio_list):
            if not audio.dtype == np.int16:
                audio_list[i] = self._to_int16(audio, bits_per_sample)

        # `flag=1` is forwarded unchanged; its meaning is defined by the helper.
        score_array = self.kaldi_helper.score(
            audio_list, self.utt_ids, n_jobs=n_jobs, flag=1,
            train_ivector_scp=self.train_ivector_scp, debug=debug)
        score_array = (score_array - self.z_norm_means) / self.z_norm_stds
        return score_array  # (n_audios, n_spks) or (n_spks, )

    def make_decisions(self, audios, fs=16000, bits_per_sample=16, n_jobs=10, debug=False):
        """Pick the best-scoring enrolled speaker for each input audio.

        Arguments are forwarded unchanged to `score`.

        Returns:
            Tuple ``(decisions, score_array)``. `decisions` holds, per audio,
            the index of the highest score along the speaker axis (an index
            into the sorted speaker list); for a single audio it is a bare
            index and `score_array` is flattened to 1-D.
        """
        score_array = self.score(audios, fs=fs, bits_per_sample=bits_per_sample,
                                 n_jobs=n_jobs, debug=debug)
        if len(score_array.shape) == 1:
            score_array = score_array[np.newaxis, :]

        decisions = list(np.argmax(score_array, axis=1))
        if len(decisions) == 1:
            decisions = decisions[0]
            score_array = score_array.flatten()
        return decisions, score_array