diff --git a/examples/librimix/tse/v2/confs/bsrnn.yaml b/examples/librimix/tse/v2/confs/bsrnn.yaml index f8b9791..9a2f4d4 100644 --- a/examples/librimix/tse/v2/confs/bsrnn.yaml +++ b/examples/librimix/tse/v2/confs/bsrnn.yaml @@ -25,9 +25,8 @@ dataset_args: noise_enroll_prob: 0 # prob to add noise aug on enrollment speech # Self-estimated Speech Augmentation (SSA). Please ref our SLT paper: https://www.arxiv.org/abs/2409.09589 # only Single-optimization method is supported here. - # if you want to use multi-optimization, please ref bsrnn_multi_optimization.yaml - SSA_enroll_prob: - Single_optimization: 0.6 # prob to add SSA on enrollment speech + # if you want to use multi-optimization, please ref bsrnn_multi_optim.yaml + SSA_enroll_prob: 0 # prob to add SSA on enrollment speech enable_amp: false exp_dir: exp/BSRNN diff --git a/examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml b/examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml new file mode 100644 index 0000000..f390a02 --- /dev/null +++ b/examples/librimix/tse/v2/confs/bsrnn_multi_optim.yaml @@ -0,0 +1,118 @@ +dataloader_args: + batch_size: 8 #RTX2080 1, V100: 8, A800: 16 + drop_last: true + num_workers: 6 + pin_memory: true + prefetch_factor: 6 + +dataset_args: + resample_rate: &sr 16000 + sample_num_per_epoch: 0 + shuffle: true + shuffle_args: + shuffle_size: 2500 + chunk_len: 48000 + speaker_feat: &speaker_feat False + fbank_args: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + noise_lmdb_file: './data/musan/lmdb' + noise_prob: 0 # prob to add noise aug per sample + specaug_enroll_prob: 0 # prob to apply SpecAug on fbank of enrollment speech + reverb_enroll_prob: 0 # prob to add reverb aug on enrollment speech + noise_enroll_prob: 0 # prob to add noise aug on enrollment speech + + +enable_amp: false +exp_dir: exp/BSRNN +gpus: '0,1' +log_batch_interval: 100 + +#Please refer to our SLT paper https://www.arxiv.org/abs/2409.09589 +# to check our parameter settings. +loss: SISDR +loss_args: + loss_posi: [[0,1]] + loss_weight: [[0.4,0.6]] + +#if you wanna use CE loss, multi_task needs to be set True +# loss: [SISDR, CE] ### For joint training the speaker encoder with CE loss. Put SISDR in the first position for validation set +# loss_args: +# loss_posi: [[0,1],[2,3]] +# loss_weight: [[0.36,0.54],[0.04,0.06]] + +model: + tse_model: BSRNN_Multi +model_args: + tse_model: + sr: *sr + win: 512 + stride: 128 + feature_dim: 128 + num_repeat: 6 + spk_fuse_type: 'multiply' + use_spk_transform: False + multi_fuse: False # Fuse the speaker embedding multiple times. 
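+    # With multi_fuse: False the embedding is fused once, before the stacked
+    # BSNet blocks; set it to True to place a SpeakerFuseLayer in front of
+    # every BSNet block (see FuseSeparation in bsrnn_multi_optim.py below).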
+ joint_training: True # Always set True, use "spk_model_freeze" to control if use pre-trained speaker encoders + ####### ResNet The pretrained speaker encoders are available from: https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md + spk_model: ResNet34 # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 + spk_model_init: False #./wespeaker_models/voxceleb_resnet34/avg_model.pt + spk_args: + feat_dim: 80 + embed_dim: &embed_dim 256 + pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False + ####### Ecapa_TDNN + # spk_model: ECAPA_TDNN_GLOB_c512 + # spk_model_init: False #./wespeaker_models/voxceleb_ECAPA512/avg_model.pt + # spk_args: + # embed_dim: &embed_dim 192 + # feat_dim: 80 + # pooling_func: ASTP + ####### CAMPPlus + # spk_model: CAMPPlus + # spk_model_init: False + # spk_args: + # feat_dim: 80 + # embed_dim: &embed_dim 192 + # pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP + ################################################################# + spk_emb_dim: *embed_dim + spk_model_freeze: False # Related to train TSE model with pre-trained speaker encoder, Control if freeze the weights in speaker encoder + spk_feat: *speaker_feat #if do NOT wanna process the feat when processing data, set &speaker_feat to False, then the feat_type will be used + feat_type: "consistent" + multi_task: False + spksInTrain: 251 #wsj0-2mix: 101; Libri2mix-100: 251; Libri2mix-360:921 + +# find_unused_parameters: True + +model_init: + tse_model: null + discriminator: null + spk_model: null + +num_avg: 2 +num_epochs: 150 + +optimizer: + tse_model: Adam +optimizer_args: + tse_model: + lr: 0.001 # NOTICE: These args do NOT work! The initial lr is determined in the scheduler_args currently! + weight_decay: 0.0001 + +clip_grad: 5.0 +save_epoch_interval: 1 + +scheduler: + tse_model: ExponentialDecrease +scheduler_args: + tse_model: + final_lr: 2.5e-05 + initial_lr: 0.001 + warm_from_zero: false + warm_up_epoch: 0 + +seed: 42 diff --git a/wesep/bin/train.py b/wesep/bin/train.py index 65048c5..36a2260 100644 --- a/wesep/bin/train.py +++ b/wesep/bin/train.py @@ -320,7 +320,7 @@ def train(config="conf/config.yaml", **kwargs): device=device, se_loss_weight=loss_args, multi_task=multi_task, - SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", None), + SSA_enroll_prob=configs["dataset_args"].get("SSA_enroll_prob", 0), fbank_args= configs["dataset_args"].get('fbank_args',None), sample_rate=configs["dataset_args"]['resample_rate'], speaker_feat = configs["dataset_args"].get('speaker_feat', True) diff --git a/wesep/models/__init__.py b/wesep/models/__init__.py index 2d59fc5..6694734 100644 --- a/wesep/models/__init__.py +++ b/wesep/models/__init__.py @@ -3,11 +3,13 @@ import wesep.models.dpccn as dpccn import wesep.models.tfgridnet as tfgridnet import wesep.modules.metric_gan.discriminator as discriminator - +import wesep.models.bsrnn_multi_optim as bsrnn_multi def get_model(model_name: str): if model_name.startswith("ConvTasNet"): return getattr(convtasnet, model_name) + elif model_name.startswith("BSRNN_Multi"): + return getattr(bsrnn_multi,model_name) elif model_name.startswith("BSRNN"): return getattr(bsrnn, model_name) elif model_name.startswith("DPCCN"): diff --git a/wesep/models/bsrnn_multi_optim.py b/wesep/models/bsrnn_multi_optim.py new file mode 100644 index 0000000..8a0004d --- /dev/null +++ b/wesep/models/bsrnn_multi_optim.py @@ -0,0 +1,464 @@ +from __future__ import print_function +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn 
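+
+# BSRNN_Multi pairs the BSRNN separator with a jointly trained speaker
+# encoder. When gradients are enabled, forward() additionally re-enrolls on a
+# detached copy of its own estimate (Self-estimated Speech Augmentation,
+# multi-optimization; https://www.arxiv.org/abs/2409.09589) and returns
+# (s, self_s, predict_speaker_lable, self_predict_speaker_lable).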
+import torchaudio +from wespeaker.models.speaker_model import get_speaker_model + +from wesep.modules.common.speaker import PreEmphasis +from wesep.modules.common.speaker import SpeakerFuseLayer +from wesep.modules.common.speaker import SpeakerTransform + +class ResRNN(nn.Module): + + def __init__(self, input_size, hidden_size, bidirectional=True): + super(ResRNN, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.eps = torch.finfo(torch.float32).eps + + self.norm = nn.GroupNorm(1, input_size, self.eps) + self.rnn = nn.LSTM( + input_size, + hidden_size, + 1, + batch_first=True, + bidirectional=bidirectional, + ) + + # linear projection layer + self.proj = nn.Linear(hidden_size * 2, + input_size) # hidden_size = feature_dim * 2 + + def forward(self, input): + # input shape: batch, dim, seq + + rnn_output, _ = self.rnn(self.norm(input).transpose(1, 2).contiguous()) + rnn_output = self.proj(rnn_output.contiguous().view( + -1, rnn_output.shape[2])).view(input.shape[0], input.shape[2], + input.shape[1]) + + return input + rnn_output.transpose(1, 2).contiguous() + + +""" +TODO : attach the speaker embedding to each input +Input shape:(B,feature_dim + spk_emb_dim , T) +""" + + +class BSNet(nn.Module): + + def __init__(self, in_channel, nband=7, bidirectional=True): + super(BSNet, self).__init__() + + self.nband = nband + self.feature_dim = in_channel // nband + self.band_rnn = ResRNN(self.feature_dim, + self.feature_dim * 2, + bidirectional=bidirectional) + self.band_comm = ResRNN(self.feature_dim, + self.feature_dim * 2, + bidirectional=bidirectional) + + def forward(self, input, dummy: Optional[torch.Tensor] = None): + # input shape: B, nband*N, T + B, N, T = input.shape + + band_output = self.band_rnn( + input.view(B * self.nband, self.feature_dim, + -1)).view(B, self.nband, -1, T) + + # band comm + band_output = (band_output.permute(0, 3, 2, 1).contiguous().view( + B * T, -1, self.nband)) + output = (self.band_comm(band_output).view( + B, T, -1, self.nband).permute(0, 3, 2, 1).contiguous()) + + return output.view(B, N, T) + + +class FuseSeparation(nn.Module): + + def __init__( + self, + nband=7, + num_repeat=6, + feature_dim=128, + spk_emb_dim=256, + spk_fuse_type="concat", + multi_fuse=True, + ): + """ + + :param nband : len(self.band_width) + """ + super(FuseSeparation, self).__init__() + self.multi_fuse = multi_fuse + self.nband = nband + self.feature_dim = feature_dim + self.separation = nn.ModuleList([]) + if self.multi_fuse: + for _ in range(num_repeat): + self.separation.append( + SpeakerFuseLayer( + embed_dim=spk_emb_dim, + feat_dim=feature_dim, + fuse_type=spk_fuse_type, + )) + self.separation.append(BSNet(nband * feature_dim, nband)) + else: + self.separation.append( + SpeakerFuseLayer( + embed_dim=spk_emb_dim, + feat_dim=feature_dim, + fuse_type=spk_fuse_type, + )) + for _ in range(num_repeat): + self.separation.append(BSNet(nband * feature_dim, nband)) + + def forward(self, x, spk_embedding, nch: torch.Tensor = torch.tensor(1)): + """ + x: [B, nband, feature_dim, T] + out: [B, nband, feature_dim, T] + """ + batch_size = x.shape[0] + + if self.multi_fuse: + for i, sep_func in enumerate(self.separation): + x = sep_func(x, spk_embedding) + if i % 2 == 0: + x = x.view(batch_size * nch, self.nband * self.feature_dim, + -1) + else: + x = x.view(batch_size * nch, self.nband, self.feature_dim, + -1) + else: + x = self.separation[0](x, spk_embedding) + x = x.view(batch_size * nch, self.nband * self.feature_dim, -1) + for idx, sep in 
enumerate(self.separation): + if idx > 0: + x = sep(x, spk_embedding) + x = x.view(batch_size * nch, self.nband, self.feature_dim, -1) + return x + + +class BSRNN_Multi(nn.Module): + # self, sr=16000, win=512, stride=128, feature_dim=128, num_repeat=6, + # use_bidirectional=True + def __init__( + self, + spk_emb_dim=256, + sr=16000, + win=512, + stride=128, + feature_dim=128, + num_repeat=6, + use_spk_transform=True, + use_bidirectional=True, + spk_fuse_type="concat", + multi_fuse=True, + joint_training=True, + multi_task=False, + spksInTrain=251, + spk_model=None, + spk_model_init=None, + spk_model_freeze=False, + spk_args=None, + spk_feat=False, + feat_type="consistent", + ): + super(BSRNN_Multi, self).__init__() + + self.sr = sr + self.win = win + self.stride = stride + self.group = self.win // 2 + self.enc_dim = self.win // 2 + 1 + self.feature_dim = feature_dim + self.eps = torch.finfo(torch.float32).eps + self.spk_emb_dim = spk_emb_dim + self.joint_training = joint_training + self.spk_feat = spk_feat + self.feat_type = feat_type + self.spk_model_freeze = spk_model_freeze + self.multi_task = multi_task + + # 0-1k (100 hop), 1k-4k (250 hop), + # 4k-8k (500 hop), 8k-16k (1k hop), + # 16k-20k (2k hop), 20k-inf + + # 0-8k (1k hop), 8k-16k (2k hop), 16k + bandwidth_100 = int(np.floor(100 / (sr / 2.0) * self.enc_dim)) + bandwidth_200 = int(np.floor(200 / (sr / 2.0) * self.enc_dim)) + bandwidth_500 = int(np.floor(500 / (sr / 2.0) * self.enc_dim)) + bandwidth_2k = int(np.floor(2000 / (sr / 2.0) * self.enc_dim)) + + # add up to 8k + self.band_width = [bandwidth_100] * 15 + self.band_width += [bandwidth_200] * 10 + self.band_width += [bandwidth_500] * 5 + self.band_width += [bandwidth_2k] * 1 + + self.band_width.append(self.enc_dim - int(np.sum(self.band_width))) + self.nband = len(self.band_width) + + if use_spk_transform: + self.spk_transform = SpeakerTransform() + else: + self.spk_transform = nn.Identity() + + if joint_training: + self.spk_model = get_speaker_model(spk_model)(**spk_args) + if spk_model_init: + pretrained_model = torch.load(spk_model_init) + state = self.spk_model.state_dict() + for key in state.keys(): + if key in pretrained_model.keys(): + state[key] = pretrained_model[key] + # print(key) + else: + print("not %s loaded" % key) + self.spk_model.load_state_dict(state) + if spk_model_freeze: + for param in self.spk_model.parameters(): + param.requires_grad = False + if not spk_feat: + if feat_type == "consistent": + self.preEmphasis = PreEmphasis() + self.spk_encoder = torchaudio.transforms.MelSpectrogram( + sample_rate=sr, + n_fft=win, + win_length=win, + hop_length=stride, + f_min=20, + window_fn=torch.hamming_window, + n_mels=spk_args["feat_dim"], + ) + else: + self.preEmphasis = nn.Identity() + self.spk_encoder = nn.Identity() + + if multi_task: + self.pred_linear = nn.Linear(spk_emb_dim, spksInTrain) + else: + self.pred_linear = nn.Identity() + + self.BN = nn.ModuleList([]) + for i in range(self.nband): + self.BN.append( + nn.Sequential( + nn.GroupNorm(1, self.band_width[i] * 2, self.eps), + nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1), + )) + + self.separator = FuseSeparation( + nband=self.nband, + num_repeat=num_repeat, + feature_dim=feature_dim, + spk_emb_dim=spk_emb_dim, + spk_fuse_type=spk_fuse_type, + multi_fuse=multi_fuse, + ) + + # self.proj = nn.Linear(hidden_size*2, input_size) + + self.mask = nn.ModuleList([]) + for i in range(self.nband): + self.mask.append( + nn.Sequential( + nn.GroupNorm(1, self.feature_dim, + torch.finfo(torch.float32).eps), + 
nn.Conv1d(self.feature_dim, self.feature_dim * 4, 1), + nn.Tanh(), + nn.Conv1d(self.feature_dim * 4, self.feature_dim * 4, 1), + nn.Tanh(), + nn.Conv1d(self.feature_dim * 4, self.band_width[i] * 4, 1), + )) + + def pad_input(self, input, window, stride): + """ + Zero-padding input according to window/stride size. + """ + batch_size, nsample = input.shape + + # pad the signals at the end for matching the window/stride size + rest = window - (stride + nsample % window) % window + if rest > 0: + pad = torch.zeros(batch_size, rest).type(input.type()) + input = torch.cat([input, pad], 1) + pad_aux = torch.zeros(batch_size, stride).type(input.type()) + input = torch.cat([pad_aux, input, pad_aux], 1) + + return input, rest + + def forward(self, input, embeddings): + # input shape: (B, C, T) + + wav_input = input + spk_emb_input = embeddings + batch_size, nsample = wav_input.shape + nch = 1 + + # frequency-domain separation + spec = torch.stft( + wav_input, + n_fft=self.win, + hop_length=self.stride, + window=torch.hann_window(self.win).to(wav_input.device).type( + wav_input.type()), + return_complex=True, + ) + + # concat real and imag, split to subbands + spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T + subband_spec = [] + subband_mix_spec = [] + band_idx = 0 + for i in range(len(self.band_width)): + subband_spec.append(spec_RI[:, :, band_idx:band_idx + + self.band_width[i]].contiguous()) + subband_mix_spec.append(spec[:, band_idx:band_idx + + self.band_width[i]]) # B*nch, BW, T + band_idx += self.band_width[i] + + # normalization and bottleneck + subband_feature = [] + for i, bn_func in enumerate(self.BN): + subband_feature.append( + bn_func(subband_spec[i].view(batch_size * nch, + self.band_width[i] * 2, -1))) + subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T + # print(subband_feature.size(), spk_emb_input.size()) + + predict_speaker_lable = torch.tensor(0.0).to( + spk_emb_input.device) # dummy + if self.joint_training: + if not self.spk_feat: + if self.feat_type == "consistent": + with torch.no_grad(): + spk_emb_input = self.preEmphasis(spk_emb_input) + spk_emb_input = self.spk_encoder(spk_emb_input) + 1e-8 + spk_emb_input = spk_emb_input.log() + spk_emb_input = spk_emb_input - torch.mean( + spk_emb_input, dim=-1, keepdim=True) + spk_emb_input = spk_emb_input.permute(0, 2, 1) + + tmp_spk_emb_input = self.spk_model(spk_emb_input) + if isinstance(tmp_spk_emb_input, tuple): + spk_emb_input = tmp_spk_emb_input[-1] + else: + spk_emb_input = tmp_spk_emb_input + predict_speaker_lable = self.pred_linear(spk_emb_input) + + spk_embedding = self.spk_transform(spk_emb_input) + spk_embedding = spk_embedding.unsqueeze(1).unsqueeze(3) + + sep_output = self.separator(subband_feature, spk_embedding, + torch.tensor(nch)) + + sep_subband_spec = [] + for i, mask_func in enumerate(self.mask): + this_output = mask_func(sep_output[:, i]).view( + batch_size * nch, 2, 2, self.band_width[i], -1) + this_mask = this_output[:, 0] * torch.sigmoid( + this_output[:, 1]) # B*nch, 2, K, BW, T + this_mask_real = this_mask[:, 0] # B*nch, K, BW, T + this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T + est_spec_real = (subband_mix_spec[i].real * this_mask_real - + subband_mix_spec[i].imag * this_mask_imag + ) # B*nch, BW, T + est_spec_imag = (subband_mix_spec[i].real * this_mask_imag + + subband_mix_spec[i].imag * this_mask_real + ) # B*nch, BW, T + sep_subband_spec.append(torch.complex(est_spec_real, + est_spec_imag)) + est_spec = torch.cat(sep_subband_spec, 1) # B*nch, F, T + output = 
torch.istft( + est_spec.view(batch_size * nch, self.enc_dim, -1), + n_fft=self.win, + hop_length=self.stride, + window=torch.hann_window(self.win).to(wav_input.device).type( + wav_input.type()), + length=nsample, + ) + + output = output.view(batch_size, nch, -1) + s = torch.squeeze(output, dim=1) + if torch.is_grad_enabled(): + self_embedding = s.detach() + self_predict_speaker_lable = torch.tensor(0.0).to(self_embedding.device) # dummy + if self.joint_training: + if self.feat_type=='consistent': + with torch.no_grad(): + self_embedding = self.preEmphasis(self_embedding) + self_embedding = self.spk_encoder(self_embedding)+1e-8 + self_embedding = self_embedding.log() + self_embedding = self_embedding - torch.mean(self_embedding, dim=-1, keepdim=True) + self_embedding = self_embedding.permute(0, 2, 1) + + self_tmp_spk_emb_input = self.spk_model(self_embedding) + if isinstance(self_tmp_spk_emb_input,tuple): + self_spk_emb_input = self_tmp_spk_emb_input[-1] + else: + self_spk_emb_input = self_tmp_spk_emb_input + self_predict_speaker_lable = self.pred_linear(self_spk_emb_input) + + self_spk_embedding = self.spk_transform(self_spk_emb_input) + self_spk_embedding = self_spk_embedding.unsqueeze(1).unsqueeze(3) + + self_sep_output = self.separator(subband_feature, self_spk_embedding, torch.tensor(nch)) + + self_sep_subband_spec = [] + for i, mask_func in enumerate(self.mask): + this_output = mask_func(self_sep_output[:, i]).view(batch_size * nch, 2, 2, self.band_width[i], -1) + this_mask = this_output[:, 0] * torch.sigmoid(this_output[:, 1]) # B*nch, 2, K, BW, T + this_mask_real = this_mask[:, 0] # B*nch, K, BW, T + this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T + est_spec_real = subband_mix_spec[i].real * this_mask_real - subband_mix_spec[ + i].imag * this_mask_imag # B*nch, BW, T + est_spec_imag = subband_mix_spec[i].real * this_mask_imag + subband_mix_spec[ + i].imag * this_mask_real # B*nch, BW, T + self_sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag)) + self_est_spec = torch.cat(self_sep_subband_spec, 1) # B*nch, F, T + self_output = torch.istft(self_est_spec.view(batch_size * nch, self.enc_dim, -1), + n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(wav_input.device).type(wav_input.type()), + length=nsample) + + self_output = self_output.view(batch_size, nch, -1) + self_s = torch.squeeze(self_output, dim=1) + + return s,self_s, predict_speaker_lable,self_predict_speaker_lable + + return s, predict_speaker_lable + + +if __name__ == "__main__": + from thop import profile, clever_format + + model = BSRNN_Multi( + spk_emb_dim=256, + sr=16000, + win=512, + stride=128, + feature_dim=128, + num_repeat=6, + spk_fuse_type="additive", + ) + + s = 0 + for param in model.parameters(): + s += np.product(param.size()) + print("# of parameters: " + str(s / 1024.0 / 1024.0)) + x = torch.randn(4, 32000) + spk_embeddings = torch.randn(4, 256) + output = model(x, spk_embeddings) + print(output.shape) + + macs, params = profile(model, inputs=(x, spk_embeddings)) + macs, params = clever_format([macs, params], "%.3f") + print(macs, params) diff --git a/wesep/utils/executor.py b/wesep/utils/executor.py index ca7e95a..73fe414 100644 --- a/wesep/utils/executor.py +++ b/wesep/utils/executor.py @@ -85,8 +85,8 @@ def train( spk_label = spk_label.to(device) with torch.cuda.amp.autocast(enabled=enable_amp): - if SSA_enroll_prob['Single_optimization'] >0: - if SSA_enroll_prob['Single_optimization']>random.random(): + if SSA_enroll_prob >0: + if 
SSA_enroll_prob>random.random(): with torch.no_grad(): outputs = model(features, enroll) est_speech = outputs[0] @@ -101,7 +101,6 @@ def train( outputs = model(features, enroll) if not isinstance(outputs, (list, tuple)): outputs = [outputs] - loss = 0 for ii in range(len(criterion)): # se_loss_weight: ([position in outputs[0], [1]],
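
For reference, the executor now reads SSA_enroll_prob as a plain scalar; the old {'Single_optimization': p} dict is gone. The sketch below paraphrases the single-optimization branch from the wesep/utils/executor.py hunk above. The step where est_speech replaces the enrollment for the gradient-enabled pass is not visible in the hunk and is assumed here, so treat this as an illustration rather than the exact executor code.

import random

import torch


def ssa_enrollment(model, features, enroll, ssa_enroll_prob):
    """Self-estimated Speech Augmentation, single-optimization variant.

    With probability ssa_enroll_prob, run a no-grad forward pass and use the
    model's own estimate of the target speech as the enrollment cue for the
    subsequent gradient-enabled pass (assumption: this replacement step is
    not shown in the hunk above).
    """
    if ssa_enroll_prob > 0 and ssa_enroll_prob > random.random():
        with torch.no_grad():
            outputs = model(features, enroll)
            est_speech = outputs[0]  # first output is the estimated target speech
        enroll = est_speech
    return enroll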