Skip to content

Commit

Permalink
Self-estimated Speech Augmentation (multi-optimization) to bsrnn_mult…
Browse files Browse the repository at this point in the history
…i_optim.yaml
  • Loading branch information
mrjunjieli committed Sep 21, 2024
1 parent 8527114 commit 8d8c9ac
Show file tree
Hide file tree
Showing 6 changed files with 590 additions and 8 deletions.
5 changes: 2 additions & 3 deletions examples/librimix/tse/v2/confs/bsrnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ dataset_args:
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech
# Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589
# only Single-optimization method is supported here.
# if you want to use multi-optimization, please ref bsrnn_multi_optimization.yaml
SSA_enroll_prob:
Single_optimization: 0.6 # prob to add SSA on enrollment speech
# if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml
SSA_enroll_prob: 0 # prob to add SSA on enrollment speech

enable_amp: false
exp_dir: exp/BSRNN
Expand Down
118 changes: 118 additions & 0 deletions examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
dataloader_args:
batch_size: 8 #RTX2080 1, V100: 8, A800: 16
drop_last: true
num_workers: 6
pin_memory: true
prefetch_factor: 6

dataset_args:
resample_rate: &sr 16000
sample_num_per_epoch: 0
shuffle: true
shuffle_args:
shuffle_size: 2500
chunk_len: 48000
speaker_feat: &speaker_feat False
fbank_args:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
noise_lmdb_file: './data/musan/lmdb'
noise_prob: 0 # prob to add noise aug per sample
specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech
reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech
noise_enroll_prob: 0 # prob to add noise aug on enrollment speech


enable_amp: false
exp_dir: exp/BSRNN
gpus: '0,1'
log_batch_interval: 100

#Please refer to our SLT paper https://www.arxiv.org/abs/2409.09589
# to check our parameter settings.
loss: SISDR
loss_args:
loss_posi: [[0,1]]
loss_weight: [[0.4,0.6]]

#if you wanna use CE loss, multi_task needs to be set True
# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set
# loss_args:
# loss_posi: [[0,1],[2,3]]
# loss_weight: [[0.36,0.54],[0.04,0.06]]

model:
tse_model: BSRNN_Multi
model_args:
tse_model:
sr: *sr
win: 512
stride: 128
feature_dim: 128
num_repeat: 6
spk_fuse_type: 'multiply'
use_spk_transform: False
multi_fuse: False # Fuse the speaker embedding multiple times.
joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders
####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt
spk_args:
feat_dim: 80
embed_dim: &embed_dim 256
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
####### Ecapa_TDNN
# spk_model: ECAPA_TDNN_GLOB_c512
# spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt
# spk_args:
# embed_dim: &embed_dim 192
# feat_dim: 80
# pooling_func: ASTP
####### CAMPPlus
# spk_model: CAMPPlus
# spk_model_init: False
# spk_args:
# feat_dim: 80
# embed_dim: &embed_dim 192
# pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
#################################################################
spk_emb_dim: *embed_dim
spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder
spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used
feat_type: "consistent"
multi_task: False
spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921

# find_unused_parameters: True

model_init:
tse_model: null
discriminator: null
spk_model: null

num_avg: 2
num_epochs: 150

optimizer:
tse_model: Adam
optimizer_args:
tse_model:
lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently!
weight_decay: 0.0001

clip_grad: 5.0
save_epoch_interval: 1

scheduler:
tse_model: ExponentialDecrease
scheduler_args:
tse_model:
final_lr: 2.5e-05
initial_lr: 0.001
warm_from_zero: false
warm_up_epoch: 0

seed: 42
2 changes: 1 addition & 1 deletion wesep/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def train(config="conf/config.yaml", **kwargs):
device=device,
se_loss_weight=loss_args,
multi_task=multi_task,
SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", None),
SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", 0),
fbank_args= configs["dataset_args"].get('fbank_args',None),
sample_rate=configs["dataset_args"]['resample_rate'],
speaker_feat = configs["dataset_args"].get('speaker_feat', True)
Expand Down
4 changes: 3 additions & 1 deletion wesep/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import wesep.models.dpccn as dpccn
import wesep.models.tfgridnet as tfgridnet
import wesep.modules.metric_gan.discriminator as discriminator

import wesep.models.bsrnn_multi_optim as bsrnn_multi

def get_model(model_name: str):
if model_name.startswith("ConvTasNet"):
return getattr(convtasnet, model_name)
elif model_name.startswith("BSRNN_Multi"):
return getattr(bsrnn_multi,model_name)
elif model_name.startswith("BSRNN"):
return getattr(bsrnn, model_name)
elif model_name.startswith("DPCCN"):
Expand Down
Loading

0 comments on commit 8d8c9ac

Please sign in to comment.