diff --git a/wespeaker/cli/hub.py b/wespeaker/cli/hub.py
index 384c74a6..79a301b6 100644
--- a/wespeaker/cli/hub.py
+++ b/wespeaker/cli/hub.py
@@ -72,6 +72,8 @@ class Hub(object):
     Assets = {
         "chinese": "cnceleb_resnet34.tar.gz",
         "english": "voxceleb_resnet221_LM.tar.gz",
+        "campplus": "campplus_cn_common_200k.tar.gz",
+        "eres2net": "eres2net_cn_commom_200k.tar.gz",
     }
 
     def __init__(self) -> None:
diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py
index 5dec4b2c..a4491a5d 100644
--- a/wespeaker/cli/speaker.py
+++ b/wespeaker/cli/speaker.py
@@ -52,6 +52,7 @@ def __init__(self, model_dir: str):
         self.resample_rate = 16000
         self.apply_vad = False
         self.device = torch.device('cpu')
+        self.wavform_norm = False
 
         # diarization parmas
         self.diar_num_spks = None
@@ -64,6 +65,9 @@ def __init__(self, model_dir: str):
         self.diar_batch_size = 32
         self.diar_subseg_cmn = True
 
+    def set_wavform_norm(self, wavform_norm: bool):
+        self.wavform_norm = wavform_norm
+
     def set_resample_rate(self, resample_rate: int):
         self.resample_rate = resample_rate
 
@@ -132,7 +136,8 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
         return embeddings
 
     def extract_embedding(self, audio_path: str):
-        pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
+        pcm, sample_rate = torchaudio.load(audio_path,
+                                           normalize=self.wavform_norm)
         if self.apply_vad:
             # TODO(Binbin Zhang): Refine the segments logic, here we just
             # suppose there is only silence at the start/end of the speech
@@ -160,7 +165,6 @@ def extract_embedding(self, audio_path: str):
         feats = feats.to(self.device)
         self.model.eval()
         with torch.no_grad():
-            # _, outputs = self.model(feats)
             outputs = self.model(feats)
         outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
         embedding = outputs[0].to(torch.device('cpu'))
@@ -301,7 +305,14 @@ def load_model_local(model_dir: str) -> Speaker:
 def main():
     args = get_args()
     if args.pretrain == "":
-        model = load_model(args.language)
+        if args.campplus:
+            model = load_model("campplus")
+            model.set_wavform_norm(True)
+        elif args.eres2net:
+            model = load_model("eres2net")
+            model.set_wavform_norm(True)
+        else:
+            model = load_model(args.language)
     else:
         model = load_model_local(args.pretrain)
     model.set_resample_rate(args.resample_rate)
diff --git a/wespeaker/cli/utils.py b/wespeaker/cli/utils.py
index 8cb1b78c..8289114b 100644
--- a/wespeaker/cli/utils.py
+++ b/wespeaker/cli/utils.py
@@ -37,6 +37,16 @@ def get_args():
                         ],
                         default='chinese',
                         help='language type')
+    parser.add_argument(
+        '--campplus',
+        action='store_true',
+        help='whether to use the damo/speech_campplus_sv_zh-cn_16k-common model'
+    )
+    parser.add_argument(
+        '--eres2net',
+        action='store_true',
+        help='whether to use the damo/speech_eres2net_sv_zh-cn_16k-common model'
+    )
     parser.add_argument('-p',
                         '--pretrain',
                         type=str,
diff --git a/wespeaker/models/eres2net.py b/wespeaker/models/eres2net.py
index 5147c7b2..9e56219a 100644
--- a/wespeaker/models/eres2net.py
+++ b/wespeaker/models/eres2net.py
@@ -103,14 +103,20 @@ def forward(self, x, ds_y):
 
 
 class BasicBlockERes2Net(nn.Module):
-    expansion = 2
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2Net, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.conv1 = conv1x1(in_planes, width * scale, stride)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         convs = []
         bns = []
@@ -162,14 +168,20 @@ def forward(self, x):
 
 
 class BasicBlockERes2Net_diff_AFF(nn.Module):
-    expansion = 2
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2Net_diff_AFF, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.conv1 = conv1x1(in_planes, width * scale, stride)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         # to meet the torch.jit.script export requirements
         self.conv2_1 = conv3x3(width, width)
@@ -232,6 +244,9 @@ class ERes2Net(nn.Module):
     def __init__(self,
                  m_channels,
                  num_blocks,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2,
                  block=BasicBlockERes2Net,
                  block_fuse=BasicBlockERes2Net_diff_AFF,
                  feat_dim=80,
@@ -244,6 +259,7 @@ def __init__(self,
         self.embed_dim = embed_dim
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
+        self.expansion = expansion
 
         self.conv1 = nn.Conv2d(1,
                                m_channels,
@@ -255,48 +271,59 @@ def __init__(self,
         self.layer1 = self._make_layer(block,
                                        m_channels,
                                        num_blocks[0],
-                                       stride=1)
+                                       stride=1,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer2 = self._make_layer(block,
                                        m_channels * 2,
                                        num_blocks[1],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer3 = self._make_layer(block_fuse,
                                        m_channels * 4,
                                        num_blocks[2],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
         self.layer4 = self._make_layer(block_fuse,
                                        m_channels * 8,
                                        num_blocks[3],
-                                       stride=2)
+                                       stride=2,
+                                       baseWidth=baseWidth,
+                                       scale=scale,
+                                       expansion=expansion)
 
         # Downsampling module for each layer
-        self.layer1_downsample = nn.Conv2d(m_channels * 2,
-                                           m_channels * 4,
+        self.layer1_downsample = nn.Conv2d(m_channels * expansion,
+                                           m_channels * expansion * 2,
                                            kernel_size=3,
                                            stride=2,
                                            padding=1,
                                            bias=False)
-        self.layer2_downsample = nn.Conv2d(m_channels * 4,
-                                           m_channels * 8,
+        self.layer2_downsample = nn.Conv2d(m_channels * expansion * 2,
+                                           m_channels * expansion * 4,
                                            kernel_size=3,
                                            padding=1,
                                            stride=2,
                                            bias=False)
-        self.layer3_downsample = nn.Conv2d(m_channels * 8,
-                                           m_channels * 16,
+        self.layer3_downsample = nn.Conv2d(m_channels * expansion * 4,
+                                           m_channels * expansion * 8,
                                            kernel_size=3,
                                            padding=1,
                                            stride=2,
                                            bias=False)
 
         # Bottom-up fusion module
-        self.fuse_mode12 = AFF(channels=m_channels * 4)
-        self.fuse_mode123 = AFF(channels=m_channels * 8)
-        self.fuse_mode1234 = AFF(channels=m_channels * 16)
+        self.fuse_mode12 = AFF(channels=m_channels * expansion * 2)
+        self.fuse_mode123 = AFF(channels=m_channels * expansion * 4)
+        self.fuse_mode1234 = AFF(channels=m_channels * expansion * 8)
 
         self.pool = getattr(pooling_layers,
-                            pooling_func)(in_dim=self.stats_dim *
-                                          block.expansion)
+                            pooling_func)(in_dim=self.stats_dim * expansion)
         self.pool_out_dim = self.pool.get_out_dim()
         self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
         if self.two_emb_layer:
@@ -306,12 +333,21 @@ def __init__(self,
             self.seg_bn_1 = nn.Identity()
             self.seg_2 = nn.Identity()
 
-    def _make_layer(self, block, planes, num_blocks, stride):
+    def _make_layer(self,
+                    block,
+                    planes,
+                    num_blocks,
+                    stride,
+                    baseWidth=32,
+                    scale=2,
+                    expansion=2):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
+            layers.append(
+                block(self.in_planes, planes, stride, baseWidth, scale,
+                      expansion))
+            self.in_planes = planes * self.expansion
         return nn.Sequential(*layers)
 
     def forward(self, x):
@@ -362,6 +398,23 @@ def ERes2Net34_Large(feat_dim,
                     two_emb_layer=two_emb_layer)
 
 
+def ERes2Net34_aug(feat_dim,
+                   embed_dim,
+                   pooling_func='TSTP',
+                   two_emb_layer=False,
+                   expansion=4,
+                   baseWidth=24,
+                   scale=3):
+    return ERes2Net(64, [3, 4, 6, 3],
+                    expansion=expansion,
+                    baseWidth=baseWidth,
+                    scale=scale,
+                    feat_dim=feat_dim,
+                    embed_dim=embed_dim,
+                    pooling_func=pooling_func,
+                    two_emb_layer=two_emb_layer)
+
+
 if __name__ == '__main__':
     x = torch.zeros(1, 200, 80)
     model = ERes2Net34_Base(feat_dim=80, embed_dim=512, two_emb_layer=False)
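
Note (not part of the patch): a minimal sketch of driving the two new models from Python, equivalent to what main() now does for --campplus / --eres2net. The audio path is a placeholder; a 16 kHz mono wav is assumed.

    from wespeaker.cli.speaker import load_model

    # "campplus" / "eres2net" resolve to the new Hub assets added above.
    model = load_model("campplus")
    # These checkpoints expect normalized (float) waveforms, so waveform
    # normalization is switched on, as main() does for these two models.
    model.set_wavform_norm(True)
    embedding = model.extract_embedding("example.wav")  # placeholder path
    print(embedding.shape)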
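
Note (not part of the patch): a quick shape check for the new ERes2Net34_aug factory, mirroring the existing __main__ block; embed_dim=192 is only an example value.

    import torch

    from wespeaker.models.eres2net import ERes2Net34_aug

    # expansion=4, baseWidth=24, scale=3 are the factory defaults; the
    # downsample/AFF channel widths above scale with m_channels * expansion.
    model = ERes2Net34_aug(feat_dim=80, embed_dim=192, two_emb_layer=False)
    model.eval()

    x = torch.zeros(1, 200, 80)  # (batch, frames, feat_dim)
    with torch.no_grad():
        out = model(x)
    out = out[-1] if isinstance(out, tuple) else out
    print(out.shape)  # should be torch.Size([1, 192])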