Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vec]add vector necessary note, test=doc #1690

Merged
merged 1 commit into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/voxceleb/sv0/conf/ecapa_tdnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
augment: True
batch_size: 32
num_workers: 2
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
skip_prep: False
split_ratio: 0.9
Expand Down Expand Up @@ -42,8 +42,16 @@ epochs: 10
save_interval: 10
log_interval: 10
learning_rate: 1e-8
max_lr: 1e-3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

字段后面多加些注释

step_size: 140000


###########################################
# loss #
###########################################
margin: 0.2
scale: 30

###########################################
# Testing #
###########################################
Expand Down
9 changes: 8 additions & 1 deletion examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Data #
###########################################
augment: True
batch_size: 16
batch_size: 32
num_workers: 2
num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
shuffle: True
Expand Down Expand Up @@ -42,7 +42,14 @@ epochs: 100
save_interval: 10
log_interval: 10
learning_rate: 1e-8
max_lr: 1e-3
step_size: 140000

###########################################
# loss #
###########################################
margin: 0.2
scale: 30

###########################################
# Testing #
Expand Down
57 changes: 41 additions & 16 deletions paddlespeech/vector/exps/ecapa_tdnn/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
"""compute the dataset embeddings

Args:
data_loader (_type_): _description_
model (_type_): _description_
mean_var_norm_emb (_type_): _description_
config (_type_): _description_
data_loader (paddle.io.Dataloader): the dataset loader to be compute the embedding
model (paddle.nn.Layer): the speaker verification model
mean_var_norm_emb : compute the embedding mean and std norm
config (yacs.config.CfgNode): the yaml config
"""
logger.info(
f'Computing embeddings on {data_loader.dataset.csv_path} dataset')
Expand All @@ -65,6 +65,17 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,


def compute_verification_scores(id2embedding, train_cohort, config):
"""Compute the verification trial scores

Args:
id2embedding (dict): the utterance embedding
train_cohort (paddle.tensor): the cohort dataset embedding
config (yacs.config.CfgNode): the yaml config

Returns:
the scores and the trial labels,
1 refers the target and 0 refers the nontarget in labels
"""
labels = []
enroll_ids = []
test_ids = []
Expand Down Expand Up @@ -119,20 +130,32 @@ def compute_verification_scores(id2embedding, train_cohort, config):


def main(args, config):
"""The main process for test the speaker verification model

Args:
args (argparse.Namespace): the command line args namespace
config (yacs.config.CfgNode): the yaml config
"""

# stage0: set the training device, cpu or gpu
# if set the gpu, paddlespeech will select a gpu according the env CUDA_VISIBLE_DEVICES
paddle.set_device(args.device)
# set the random seed, it is a must for multiprocess training
# set the random seed, it is the necessary measures for multiprocess training
seed_everything(config.seed)

# stage1: build the dnn backbone model network
# we will extract the audio embedding from the backbone model
ecapa_tdnn = EcapaTdnn(**config.model)

# stage2: build the speaker verification eval instance with backbone model
# because the checkpoint dict name has the SpeakerIdetification prefix
# so we need to create the SpeakerIdetification instance
# but we acutally use the backbone model to extact the audio embedding
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=config.num_speakers)

# stage3: load the pre-trained model
# we get the last model from the epoch and save_interval
# generally, we get the last model from the epoch
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))

Expand All @@ -143,7 +166,8 @@ def main(args, config):
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

# stage4: construct the enroll and test dataloader

# Now, wo think the enroll dataset is in the {args.data_dir}/vox/csv/enroll.csv,
# and the test dataset is in the {args.data_dir}/vox/csv/test.csv
enroll_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/enroll.csv"),
feat_type='melspectrogram',
Expand All @@ -152,22 +176,21 @@ def main(args, config):
window_size=config.window_size,
hop_length=config.hop_size)
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
shuffle=False) # Shuffle to make embedding normalization more robust.
enroll_dataset, batch_size=config.batch_size, shuffle=False)
enroll_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)

test_dataset = CSVDataset(
os.path.join(args.data_dir, "vox/csv/test.csv"),
feat_type='melspectrogram',
random_chunk=False,
n_mels=config.n_mels,
window_size=config.window_size,
hop_length=config.hop_size)

test_sampler = BatchSampler(
test_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset,
Expand All @@ -180,16 +203,17 @@ def main(args, config):
model.eval()

# stage6: global embedding norm to imporve the performance
# and we create the InputNormalization instance to process the embedding mean and std norm
logger.info(f"global embedding norm: {config.global_embedding_norm}")

# stage7: Compute embeddings of audios in enrol and test dataset from model.

if config.global_embedding_norm:
mean_var_norm_emb = InputNormalization(
norm_type="global",
mean_norm=config.embedding_mean_norm,
std_norm=config.embedding_std_norm)

# stage 7: score norm need the imposters dataset
# we select the train dataset as the idea imposters dataset
# and we select the config.n_train_snts utterance to as the final imposters dataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最好按照kaldi的名字风格命名, n_train_utts。看声纹代码里有很多命名方式很传统,不好理解。

if "score_norm" in config:
logger.info(f"we will do score norm: {config.score_norm}")
train_dataset = CSVDataset(
Expand All @@ -209,6 +233,7 @@ def main(args, config):
num_workers=config.num_workers,
return_list=True,)

# stage 8: Compute embeddings of audios in enrol and test dataset from model.
id2embedding = {}
# Run multi times to make embedding normalization more stable.
logger.info("First loop for enroll and test dataset")
Expand All @@ -225,7 +250,7 @@ def main(args, config):
mean_var_norm_emb.save(
os.path.join(args.load_checkpoint, "mean_var_norm_emb"))

# stage 8: Compute cosine scores.
# stage 9: Compute cosine scores.
train_cohort = None
if "score_norm" in config:
train_embeddings = {}
Expand All @@ -234,11 +259,11 @@ def main(args, config):
train_embeddings)
train_cohort = paddle.stack(list(train_embeddings.values()))

# compute the scores
# stage 10: compute the scores
scores, labels = compute_verification_scores(id2embedding, train_cohort,
config)

# compute the EER and threshold
# stage 11: compute the EER and threshold
scores = paddle.to_tensor(scores)
EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
logger.info(
Expand Down
20 changes: 15 additions & 5 deletions paddlespeech/vector/exps/ecapa_tdnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,39 @@


def main(args, config):
"""The main process for test the speaker verification model

Args:
args (argparse.Namespace): the command line args namespace
config (yacs.config.CfgNode): the yaml config
"""
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)

# stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
paddle.distributed.init_parallel_env()
nranks = paddle.distributed.get_world_size()
local_rank = paddle.distributed.get_rank()
# set the random seed, it is a must for multiprocess training
# set the random seed, it is the necessary measures for multiprocess training
seed_everything(config.seed)

# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
# note: some operations must be done in rank==0
train_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
dev_dataset = CSVDataset(
csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))

# we will build the augment pipeline process list
if config.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
else:
augment_pipeline = []

# stage3: build the dnn backbone model network
# in speaker verification period, we use the backbone mode to extract the audio embedding
ecapa_tdnn = EcapaTdnn(**config.model)

# stage4: build the speaker verification train instance with backbone model
Expand All @@ -77,13 +85,15 @@ def main(args, config):
# 140000 is single gpu steps
# so, in multi-gpu mode, wo reduce the step_size to 140000//nranks to enable CyclicLRScheduler
lr_schedule = CyclicLRScheduler(
base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
base_lr=config.learning_rate,
max_lr=config.max_lr,
step_size=config.step_size // nranks)
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_schedule, parameters=model.parameters())

# stage6: build the loss function, we now only support LogSoftmaxWrapper
criterion = LogSoftmaxWrapper(
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
loss_fn=AdditiveAngularMargin(margin=config.margin, scale=config.scale))

# stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model
Expand Down Expand Up @@ -225,7 +235,7 @@ def main(args, config):
print_msg += ' avg_train_cost: {:.5f} sec,'.format(
train_run_cost / config.log_interval)

print_msg += ' lr={:.4E} step/sec={:.2f} ips:{:.5f}| ETA {}'.format(
print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.5f}| ETA {}'.format(
lr, timer.timing, timer.ips, timer.eta)
logger.info(print_msg)

Expand Down
7 changes: 5 additions & 2 deletions paddlespeech/vector/io/embedding_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ def __call__(self,
lengths (paddle.Tensor): A batch of tensors containing the relative length of each
sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
computing stats on zero-padded steps.
spk_ids (_type_, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
It is used to perform per-speaker normalization when
norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").
Returns:
paddle.Tensor: The normalized feature or embedding
"""
N_batches = x.shape[0]
# print(f"x shape: {x.shape[1]}")

current_means = []
current_stds = []

Expand All @@ -75,6 +75,9 @@ def __call__(self,
actual_size = paddle.round(lengths[snt_id] *
x.shape[1]).astype("int32")
# computing actual time data statistics
# we extract the snt_id embedding from the x
# and the target paddle.Tensor will reduce an 0-axis
# so we need unsqueeze operation to recover the all axis
current_mean, current_std = self._compute_current_stats(
x[snt_id, 0:actual_size, ...].unsqueeze(0))
current_means.append(current_mean)
Expand Down