
[cli] support campplus_200k and eres2net_200k models of damo (#281)
* [cli] support campplus_200k_common and eres2net_200k_common models of damo

* [cli] fix typo
cdliang11 authored Mar 1, 2024
1 parent 170eefc commit 31921e9
Showing 4 changed files with 101 additions and 25 deletions.
2 changes: 2 additions & 0 deletions wespeaker/cli/hub.py
@@ -72,6 +72,8 @@ class Hub(object):
Assets = {
"chinese": "cnceleb_resnet34.tar.gz",
"english": "voxceleb_resnet221_LM.tar.gz",
"campplus": "campplus_cn_common_200k.tar.gz",
"eres2net": "eres2net_cn_commom_200k.tar.gz",
}

def __init__(self) -> None:
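The two new `Assets` entries map a CLI alias to the tarball that gets downloaded. A minimal sketch of how an alias resolves, assuming only what the hunk above shows (the download and extraction logic lives elsewhere in `hub.py` and is not part of this diff):

```python
# Sketch only: Assets is a class attribute of Hub, so an alias can be
# resolved without instantiating the class.
from wespeaker.cli.hub import Hub

for alias in ("campplus", "eres2net"):
    print(alias, "->", Hub.Assets[alias])
# campplus -> campplus_cn_common_200k.tar.gz
# eres2net -> eres2net_cn_commom_200k.tar.gz
```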
17 changes: 14 additions & 3 deletions wespeaker/cli/speaker.py
@@ -52,6 +52,7 @@ def __init__(self, model_dir: str):
self.resample_rate = 16000
self.apply_vad = False
self.device = torch.device('cpu')
+ self.wavform_norm = False

# diarization params
self.diar_num_spks = None
@@ -64,6 +65,9 @@ def __init__(self, model_dir: str):
self.diar_batch_size = 32
self.diar_subseg_cmn = True

+ def set_wavform_norm(self, wavform_norm: bool):
+     self.wavform_norm = wavform_norm

def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate

@@ -132,7 +136,8 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
return embeddings

def extract_embedding(self, audio_path: str):
- pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
+ pcm, sample_rate = torchaudio.load(audio_path,
+                                    normalize=self.wavform_norm)
if self.apply_vad:
# TODO(Binbin Zhang): Refine the segments logic, here we just
# suppose there is only silence at the start/end of the speech
@@ -160,7 +165,6 @@ def extract_embedding(self, audio_path: str):
feats = feats.to(self.device)
self.model.eval()
with torch.no_grad():
- # _, outputs = self.model(feats)
outputs = self.model(feats)
outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
embedding = outputs[0].to(torch.device('cpu'))
@@ -301,7 +305,14 @@ def load_model_local(model_dir: str) -> Speaker:
def main():
args = get_args()
if args.pretrain == "":
- model = load_model(args.language)
+ if args.campplus:
+     model = load_model("campplus")
+     model.set_wavform_norm(True)
+ elif args.eres2net:
+     model = load_model("eres2net")
+     model.set_wavform_norm(True)
+ else:
+     model = load_model(args.language)
else:
model = load_model_local(args.pretrain)
model.set_resample_rate(args.resample_rate)
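The `torchaudio.load` change above is the point of the new `wavform_norm` attribute: with `normalize=False` the loader keeps raw integer PCM samples, while `normalize=True` rescales them to float32 in [-1.0, 1.0], which is what the campplus and eres2net paths in `main()` switch on. A minimal sketch of the difference (the file path is a placeholder):

```python
import torchaudio

# normalize=False keeps native integer sample values (e.g. int16 for
# 16-bit PCM WAV); normalize=True rescales to float32 in [-1.0, 1.0].
# main() calls model.set_wavform_norm(True) for the campplus/eres2net
# models, so they see the normalized form. "utt.wav" is a placeholder.
raw, sr = torchaudio.load("utt.wav", normalize=False)
norm, sr = torchaudio.load("utt.wav", normalize=True)
print(raw.dtype, norm.dtype)  # e.g. torch.int16 torch.float32
```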
10 changes: 10 additions & 0 deletions wespeaker/cli/utils.py
@@ -37,6 +37,16 @@ def get_args():
],
default='chinese',
help='language type')
+ parser.add_argument(
+     '--campplus',
+     action='store_true',
+     help='whether to use the damo/speech_campplus_sv_zh-cn_16k-common model'
+ )
+ parser.add_argument(
+     '--eres2net',
+     action='store_true',
+     help='whether to use the damo/speech_eres2net_sv_zh-cn_16k-common model'
+ )
parser.add_argument('-p',
'--pretrain',
type=str,
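Both new flags are plain `store_true` switches, and `main()` above checks `--campplus` before `--eres2net`, falling back to `--language` only when neither is given. A hypothetical check of that precedence, using only the argument names from the diff:

```python
# Illustrative only: the argument names come from the diff, the rest is
# a stand-in for the real parser built in get_args().
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--campplus', action='store_true')
parser.add_argument('--eres2net', action='store_true')
parser.add_argument('--language', default='chinese')

args = parser.parse_args(['--campplus', '--eres2net'])
# main() tests args.campplus first, so campplus wins when both are set.
assert args.campplus and args.eres2net
```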
97 changes: 75 additions & 22 deletions wespeaker/models/eres2net.py
@@ -103,14 +103,20 @@ def forward(self, x, ds_y):


class BasicBlockERes2Net(nn.Module):
- expansion = 2

- def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+ def __init__(self,
+              in_planes,
+              planes,
+              stride=1,
+              baseWidth=32,
+              scale=2,
+              expansion=2):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
+ self.expansion = expansion

convs = []
bns = []
@@ -162,14 +168,20 @@ def forward(self, x):


class BasicBlockERes2Net_diff_AFF(nn.Module):
- expansion = 2

- def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
+ def __init__(self,
+              in_planes,
+              planes,
+              stride=1,
+              baseWidth=32,
+              scale=2,
+              expansion=2):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = conv1x1(in_planes, width * scale, stride)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
+ self.expansion = expansion

# to meet the torch.jit.script export requirements
self.conv2_1 = conv3x3(width, width)
@@ -232,6 +244,9 @@ class ERes2Net(nn.Module):
def __init__(self,
m_channels,
num_blocks,
+ baseWidth=32,
+ scale=2,
+ expansion=2,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
feat_dim=80,
@@ -244,6 +259,7 @@ def __init__(self,
self.embed_dim = embed_dim
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
+ self.expansion = expansion

self.conv1 = nn.Conv2d(1,
m_channels,
@@ -255,48 +271,59 @@ def __init__(self,
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
- stride=1)
+ stride=1,
+ baseWidth=baseWidth,
+ scale=scale,
+ expansion=expansion)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
- stride=2)
+ stride=2,
+ baseWidth=baseWidth,
+ scale=scale,
+ expansion=expansion)
self.layer3 = self._make_layer(block_fuse,
m_channels * 4,
num_blocks[2],
- stride=2)
+ stride=2,
+ baseWidth=baseWidth,
+ scale=scale,
+ expansion=expansion)
self.layer4 = self._make_layer(block_fuse,
m_channels * 8,
num_blocks[3],
- stride=2)
+ stride=2,
+ baseWidth=baseWidth,
+ scale=scale,
+ expansion=expansion)

# Downsampling module for each layer
- self.layer1_downsample = nn.Conv2d(m_channels * 2,
-                                    m_channels * 4,
+ self.layer1_downsample = nn.Conv2d(m_channels * expansion,
+                                    m_channels * expansion * 2,
kernel_size=3,
stride=2,
padding=1,
bias=False)
- self.layer2_downsample = nn.Conv2d(m_channels * 4,
-                                    m_channels * 8,
+ self.layer2_downsample = nn.Conv2d(m_channels * expansion * 2,
+                                    m_channels * expansion * 4,
kernel_size=3,
padding=1,
stride=2,
bias=False)
- self.layer3_downsample = nn.Conv2d(m_channels * 8,
-                                    m_channels * 16,
+ self.layer3_downsample = nn.Conv2d(m_channels * expansion * 4,
+                                    m_channels * expansion * 8,
kernel_size=3,
padding=1,
stride=2,
bias=False)

# Bottom-up fusion module
- self.fuse_mode12 = AFF(channels=m_channels * 4)
- self.fuse_mode123 = AFF(channels=m_channels * 8)
- self.fuse_mode1234 = AFF(channels=m_channels * 16)
+ self.fuse_mode12 = AFF(channels=m_channels * expansion * 2)
+ self.fuse_mode123 = AFF(channels=m_channels * expansion * 4)
+ self.fuse_mode1234 = AFF(channels=m_channels * expansion * 8)

self.pool = getattr(pooling_layers,
- pooling_func)(in_dim=self.stats_dim *
-               block.expansion)
+ pooling_func)(in_dim=self.stats_dim * expansion)
self.pool_out_dim = self.pool.get_out_dim()
self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
if self.two_emb_layer:
@@ -306,12 +333,21 @@ def __init__(self,
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()

- def _make_layer(self, block, planes, num_blocks, stride):
+ def _make_layer(self,
+                 block,
+                 planes,
+                 num_blocks,
+                 stride,
+                 baseWidth=32,
+                 scale=2,
+                 expansion=2):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
- layers.append(block(self.in_planes, planes, stride))
- self.in_planes = planes * block.expansion
+ layers.append(
+     block(self.in_planes, planes, stride, baseWidth, scale,
+           expansion))
+ self.in_planes = planes * self.expansion
return nn.Sequential(*layers)

def forward(self, x):
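The sizing pattern in the hunks above follows from each block now emitting `planes * expansion` channels: stage k of the backbone outputs `m_channels * 2**k * expansion` channels, and the downsample convolutions and AFF fusion modules are sized to match. A quick arithmetic check, using the widths from `ERes2Net34_aug` below:

```python
# Pure arithmetic, no wespeaker imports. Stage k outputs
# m_channels * 2**k * expansion channels.
def stage_channels(m_channels, expansion):
    return [m_channels * (2 ** k) * expansion for k in range(4)]

print(stage_channels(64, 4))  # ERes2Net34_aug below: [256, 512, 1024, 2048]
print(stage_channels(64, 2))  # same width with the default expansion=2
```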
@@ -362,6 +398,23 @@ def ERes2Net34_Large(feat_dim,
two_emb_layer=two_emb_layer)


+ def ERes2Net34_aug(feat_dim,
+                    embed_dim,
+                    pooling_func='TSTP',
+                    two_emb_layer=False,
+                    expansion=4,
+                    baseWidth=24,
+                    scale=3):
+     return ERes2Net(64, [3, 4, 6, 3],
+                     expansion=expansion,
+                     baseWidth=baseWidth,
+                     scale=scale,
+                     feat_dim=feat_dim,
+                     embed_dim=embed_dim,
+                     pooling_func=pooling_func,
+                     two_emb_layer=two_emb_layer)


if __name__ == '__main__':
x = torch.zeros(1, 200, 80)
model = ERes2Net34_Base(feat_dim=80, embed_dim=512, two_emb_layer=False)
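A hypothetical smoke test for the new `ERes2Net34_aug` factory, patterned after the existing `__main__` check; the embedding size and the tuple handling mirror what `Speaker.extract_embedding` expects, and nothing here is part of the commit itself:

```python
# Hypothetical smoke test, not part of the commit. The model may return a
# tuple, as Speaker.extract_embedding anticipates, so take the last element.
import torch
from wespeaker.models.eres2net import ERes2Net34_aug

x = torch.zeros(1, 200, 80)  # (batch, frames, feat_dim)
model = ERes2Net34_aug(feat_dim=80, embed_dim=512, two_emb_layer=False)
model.eval()
with torch.no_grad():
    out = model(x)
emb = out[-1] if isinstance(out, tuple) else out
print(emb.shape)  # expected: torch.Size([1, 512])
```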
