
Add diarization recipe v3 #347

Merged: 13 commits, Aug 20, 2024
1 change: 1 addition & 0 deletions README.md
@@ -60,6 +60,7 @@ pre-commit install # for clean and tidy code
```

## 🔥 News
* 2024.08.20: Update diarization recipe for VoxConverse dataset by leveraging umap dimensionality reduction and hdbscan clustering, see [#347](https://github.com/wenet-e2e/wespeaker/pull/347).
* 2024.08.18: Support using ssl pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344).
* 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320).
* 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).
34 changes: 34 additions & 0 deletions examples/voxconverse/v3/README.md
@@ -0,0 +1,34 @@
## Overview

* We suggest running this recipe on a GPU-equipped machine with onnxruntime-gpu installed.
* Dataset: voxconverse_dev, which consists of 216 utterances
* Speaker model: ResNet34 model pretrained by WeSpeaker
    * Refer to [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2)
    * [pretrained model path](https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx)
* Speaker activity detection model: oracle SAD (from the ground-truth annotation) or system SAD (VAD model pretrained by silero, https://github.com/snakers4/silero-vad)
* Clustering method: UMAP dimensionality reduction + HDBSCAN clustering (a minimal sketch is given right after this list)
* Metric: DER = MISS + FALSE ALARM + SPEAKER CONFUSION (%), spelled out after the results tables
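
The clustering step above is implemented in `wespeaker/diar/umap_clusterer.py`. The snippet below is only a minimal, self-contained sketch of the same idea, not the recipe code; it assumes the `umap-learn` and `hdbscan` packages pinned in `requirements.txt`, and the function name and all parameter values are illustrative.

```python
# Sketch: cluster sub-segment speaker embeddings with UMAP + HDBSCAN.
import hdbscan
import numpy as np
import umap


def cluster_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """embeddings: (num_subsegments, emb_dim) array of speaker embeddings."""
    # Reduce the high-dimensional embeddings to a small space where
    # density-based clustering behaves well.
    reduced = umap.UMAP(n_components=8, metric="cosine",
                        random_state=2024).fit_transform(embeddings)

    # HDBSCAN decides the number of clusters (speakers) on its own;
    # label -1 marks points it treats as noise.
    labels = hdbscan.HDBSCAN(min_cluster_size=4).fit_predict(reduced)

    # Give noise points the label of the nearest cluster centroid so that
    # every sub-segment ends up with a speaker label.
    cluster_ids = sorted(set(labels) - {-1})
    if cluster_ids and (labels == -1).any():
        centroids = np.stack([reduced[labels == c].mean(axis=0) for c in cluster_ids])
        for i in np.where(labels == -1)[0]:
            labels[i] = cluster_ids[np.argmin(np.linalg.norm(centroids - reduced[i], axis=1))]
    return labels
```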

## Results

* Dev set

| system | MISS | FA | SC | DER |
|:---|:---:|:---:|:---:|:---:|
| This repo (with oracle SAD) | 2.3 | 0.0 | 1.3 | 3.6 |
| This repo (with system SAD) | 3.4 | 0.6 | 1.4 | 5.4 |
| DIHARD 2019 baseline [^1] | 11.1 | 1.4 | 11.3 | 23.8 |
| DIHARD 2019 baseline w/ SE [^1] | 9.3 | 1.3 | 9.7 | 20.2 |
| (SyncNet ASD only) [^1] | 2.2 | 4.1 | 4.0 | 10.4 |
| (AVSE ASD only) [^1] | 2.0 | 5.9 | 4.6 | 12.4 |
| (proposed) [^1] | 2.4 | 2.3 | 3.0 | 7.7 |

* Test set

| system | MISS | FA | SC | DER |
|:---|:---:|:---:|:---:|:---:|
| This repo (with oracle SAD) | 1.6 | 0.0 | 1.9 | 3.5 |
| This repo (with system SAD) | 3.8 | 1.7 | 1.8 | 7.4 |
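
For reference, the DER values above follow the usual md-eval accounting (scored here with the 0.25 s collar set in `run.sh`):

$$
\mathrm{DER} = \frac{T_{\text{miss}} + T_{\text{false alarm}} + T_{\text{speaker confusion}}}{T_{\text{scored speech}}} \times 100\%
$$

Up to rounding, each DER column is therefore just the sum of the three error columns, e.g. 2.3 + 0.0 + 1.3 = 3.6 for the oracle-SAD dev row.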


[^1]: Spot the conversation: speaker diarisation in the wild, https://arxiv.org/pdf/2007.01216.pdf
1 change: 1 addition & 0 deletions examples/voxconverse/v3/local
1 change: 1 addition & 0 deletions examples/voxconverse/v3/path.sh
186 changes: 186 additions & 0 deletions examples/voxconverse/v3/run.sh
@@ -0,0 +1,186 @@
#!/bin/bash
# Copyright (c) 2022-2023 Xu Xiang
# 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

. ./path.sh || exit 1

stage=-1
stop_stage=-1
sad_type="oracle"
partition="dev"

# apply CMN on each sub-segment (true) or on the whole VAD segment (false)
subseg_cmn=true
# whether to print the evaluation result for each file
get_each_file_res=1

. tools/parse_options.sh

# Prerequisite
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
mkdir -p external_tools

# [1] Download evaluation toolkit
wget -c https://github.com/usnistgov/SCTK/archive/refs/tags/v2.4.12.zip -O external_tools/SCTK-v2.4.12.zip
unzip -o external_tools/SCTK-v2.4.12.zip -d external_tools

# [2] Download the ResNet34 speaker model pretrained by the WeSpeaker team
mkdir -p pretrained_models

wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O pretrained_models/voxceleb_resnet34_LM.onnx
fi


# Download VoxConverse dev/test audios and the corresponding annotations
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
mkdir -p data

# Download annotations for dev and test sets (version 0.0.3)
wget -c https://github.com/joonson/voxconverse/archive/refs/heads/master.zip -O data/voxconverse_master.zip
unzip -o data/voxconverse_master.zip -d data

# Download annotations from VoxSRC-23 validation toolkit (looks like version 0.0.2)
# cd data && git clone https://github.com/JaesungHuh/VoxSRC2023.git --recursive && cd -

# Download dev audios
mkdir -p data/dev

#wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
# The above URL may not be reachable; if so, try the link below.
# This URL is from https://github.com/joonson/voxconverse/blob/master/README.md
wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip -O data/voxconverse_dev_wav.zip
unzip -o data/voxconverse_dev_wav.zip -d data/dev

# Create wav.scp for dev audios
ls `pwd`/data/dev/audio/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/dev/wav.scp

# Test audios
mkdir -p data/test

#wget --no-check-certificate -c https://mm.kaist.ac.kr/datasets/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
# The above URL may not be reachable; if so, try the link below.
# This URL is from https://github.com/joonson/voxconverse/blob/master/README.md
wget --no-check-certificate -c https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip -O data/voxconverse_test_wav.zip
unzip -o data/voxconverse_test_wav.zip -d data/test

# Create wav.scp for test audios
ls `pwd`/data/test/voxconverse_test_wav/*.wav | awk -F/ '{print substr($NF, 1, length($NF)-4), $0}' > data/test/wav.scp
fi


# Voice activity detection
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# Set VAD min duration
min_duration=0.255

if [[ "x${sad_type}" == "xoracle" ]]; then
# Oracle SAD: handle overlapping or too-short regions in the ground-truth RTTM
while read -r utt wav_path; do
python3 wespeaker/diar/make_oracle_sad.py \
--rttm data/voxconverse-master/${partition}/${utt}.rttm \
--min-duration $min_duration
done < data/${partition}/wav.scp > data/${partition}/oracle_sad
fi

if [[ "x${sad_type}" == "xsystem" ]]; then
# System SAD: applying 'silero' VAD
python3 wespeaker/diar/make_system_sad.py \
--scp data/${partition}/wav.scp \
--min-duration $min_duration > data/${partition}/system_sad
fi
fi


# Extract fbank features
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then

[ -d "exp/${sad_type}_sad_fbank" ] && rm -r exp/${sad_type}_sad_fbank

echo "Make Fbank features and store it under exp/${sad_type}_sad_fbank"
echo "..."
bash local/make_fbank.sh \
--scp data/${partition}/wav.scp \
--segments data/${partition}/${sad_type}_sad \
--store_dir exp/${partition}_${sad_type}_sad_fbank \
--subseg_cmn ${subseg_cmn} \
--nj 24
fi

# Extract embeddings
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then

[ -d "exp/${sad_type}_sad_embedding" ] && rm -r exp/${sad_type}_sad_embedding

echo "Extract embeddings and store it under exp/${sad_type}_sad_embedding"
echo "..."
bash local/extract_emb.sh \
--scp exp/${partition}_${sad_type}_sad_fbank/fbank.scp \
--pretrained_model pretrained_models/voxceleb_resnet34_LM.onnx \
--device cuda \
--store_dir exp/${partition}_${sad_type}_sad_embedding \
--batch_size 96 \
--frame_shift 10 \
--window_secs 1.5 \
--period_secs 0.75 \
--subseg_cmn ${subseg_cmn} \
--nj 1
fi


# Applying umap clustering algorithm
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then

[ -f "exp/umap_cluster/${partition}_${sad_type}_sad_labels" ] && rm exp/umap_cluster/${partition}_${sad_type}_sad_labels

echo "Doing umap clustering and store the result in exp/umap_cluster/${partition}_${sad_type}_sad_labels"
echo "..."
python3 wespeaker/diar/umap_clusterer.py \
--scp exp/${partition}_${sad_type}_sad_embedding/emb.scp \
--output exp/umap_cluster/${partition}_${sad_type}_sad_labels
fi


# Convert labels to RTTMs
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
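# Merge the per-sub-segment labels and write them as standard RTTM SPEAKER
# lines (file id, channel, onset, duration, speaker label) into a single file.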
python3 wespeaker/diar/make_rttm.py \
--labels exp/umap_cluster/${partition}_${sad_type}_sad_labels \
--channel 1 > exp/umap_cluster/${partition}_${sad_type}_sad_rttm
fi


# Evaluate the result
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
ref_dir=data/voxconverse-master/
#ref_dir=data/VoxSRC2023/voxconverse/
echo -e "Get the DER results\n..."
perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
-c 0.25 \
-r <(cat ${ref_dir}/${partition}/*.rttm) \
-s exp/umap_cluster/${partition}_${sad_type}_sad_rttm 2>&1 | tee exp/umap_cluster/${partition}_${sad_type}_sad_res

if [ ${get_each_file_res} -eq 1 ];then
single_file_res_dir=exp/umap_cluster/${partition}_${sad_type}_single_file_res
mkdir -p $single_file_res_dir
echo -e "\nGet the DER results for each file and the results will be stored underd ${single_file_res_dir}\n..."

awk '{print $2}' exp/umap_cluster/${partition}_${sad_type}_sad_rttm | sort -u | while read file_name; do
perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
-c 0.25 \
-r <(cat ${ref_dir}/${partition}/${file_name}.rttm) \
-s <(grep "${file_name}" exp/umap_cluster/${partition}_${sad_type}_sad_rttm) > ${single_file_res_dir}/${partition}_${file_name}_res
done
echo "Done!"
fi
fi
1 change: 1 addition & 0 deletions examples/voxconverse/v3/tools
1 change: 1 addition & 0 deletions examples/voxconverse/v3/wespeaker
2 changes: 2 additions & 0 deletions requirements.txt
@@ -24,3 +24,5 @@ pypeln==0.4.9
silero-vad
pre-commit==3.5.0
s3prl
hdbscan==0.8.37
umap-learn==0.5.6
32 changes: 9 additions & 23 deletions wespeaker/cli/speaker.py
@@ -29,7 +29,7 @@
from wespeaker.cli.utils import get_args
from wespeaker.models.speaker_model import get_speaker_model
from wespeaker.utils.checkpoint import load_checkpoint
from wespeaker.diar.spectral_clusterer import cluster
from wespeaker.diar.umap_clusterer import cluster
A collaborator commented on this import change:
@JiJiJiang I am not sure whether we should change the client script.

Reply:
Yes, just keep it as the better one.

from wespeaker.diar.extract_emb import subsegment
from wespeaker.diar.make_rttm import merge_segments
from wespeaker.utils.utils import set_seed
@@ -47,6 +47,7 @@ def __init__(self, model_dir: str):
self.model = get_speaker_model(
configs['model'])(**configs['model_args'])
load_checkpoint(self.model, model_path)
self.model.eval()
self.vad = load_silero_vad()
self.table = {}
self.resample_rate = 16000
@@ -55,9 +56,6 @@ def __init__(self, model_dir: str):
self.wavform_norm = False

# diarization params
self.diar_num_spks = None
self.diar_min_num_spks = 1
self.diar_max_num_spks = 20
self.diar_min_duration = 0.255
self.diar_window_secs = 1.5
self.diar_period_secs = 0.75
@@ -83,18 +81,12 @@ def set_gpu(self, device_id: int):
self.model = self.model.to(self.device)

def set_diarization_params(self,
num_spks=None,
min_num_spks=1,
max_num_spks=20,
min_duration: float = 0.255,
window_secs: float = 1.5,
period_secs: float = 0.75,
frame_shift: int = 10,
batch_size: int = 32,
subseg_cmn: bool = True):
self.diar_num_spks = num_spks
self.diar_min_num_spks = min_num_spks
self.diar_max_num_spks = max_num_spks
self.diar_min_duration = min_duration
self.diar_window_secs = window_secs
self.diar_period_secs = period_secs
@@ -127,10 +119,10 @@ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
fbanks_array = torch.from_numpy(fbanks_array).to(self.device)
for i in tqdm(range(0, fbanks_array.shape[0], batch_size)):
batch_feats = fbanks_array[i:i + batch_size]
# _, batch_embs = self.model(batch_feats)
batch_embs = self.model(batch_feats)
batch_embs = batch_embs[-1] if isinstance(batch_embs,
tuple) else batch_embs
with torch.no_grad():
batch_embs = self.model(batch_feats)
batch_embs = batch_embs[-1] if isinstance(batch_embs,
tuple) else batch_embs
embeddings.append(batch_embs.detach().cpu().numpy())
embeddings = np.vstack(embeddings)
return embeddings
@@ -162,7 +154,7 @@ def extract_embedding(self, audio_path: str):
cmn=True)
feats = feats.unsqueeze(0)
feats = feats.to(self.device)
self.model.eval()

with torch.no_grad():
outputs = self.model(feats)
outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
@@ -251,10 +243,7 @@ def diarize(self, audio_path: str, utt: str = "unk"):

# 4. cluster
subseg2label = []
labels = cluster(embeddings,
num_spks=self.diar_num_spks,
min_num_spks=self.diar_min_num_spks,
max_num_spks=self.diar_max_num_spks)
labels = cluster(embeddings)
for (_subseg, _label) in zip(subsegs, labels):
# b, e = process_seg_id(_subseg, frame_shift=self.diar_frame_shift)
# subseg2label.append([b, e, _label])
@@ -316,10 +305,7 @@ def main():
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
model.set_gpu(args.gpu)
model.set_diarization_params(num_spks=args.diar_num_spks,
min_num_spks=args.diar_min_num_spks,
max_num_spks=args.diar_max_num_spks,
min_duration=args.diar_min_duration,
model.set_diarization_params(min_duration=args.diar_min_duration,
window_secs=args.diar_window_secs,
period_secs=args.diar_period_secs,
frame_shift=args.diar_frame_shift,
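Since the speaker-count arguments are gone, the Python-side call through this class only keeps the remaining parameters. Below is a hypothetical usage sketch based solely on the method names visible in this diff; the constructor argument and audio path are placeholders, and the packaged entry point may differ.

```python
from wespeaker.cli.speaker import Speaker

# "model_dir" and "meeting.wav" are placeholders, not paths from this PR.
model = Speaker("model_dir")
model.set_gpu(0)
# No num_spks / min_num_spks / max_num_spks any more: the UMAP + HDBSCAN
# backend infers the number of speakers from the data.
model.set_diarization_params(min_duration=0.255,
                             window_secs=1.5,
                             period_secs=0.75,
                             frame_shift=10,
                             batch_size=32,
                             subseg_cmn=True)
# Prints the merged segment/label result returned by diarize().
print(model.diarize("meeting.wav", utt="meeting"))
```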
12 changes: 0 additions & 12 deletions wespeaker/cli/utils.py
@@ -75,18 +75,6 @@ def get_args():
help='output file to save speaker embedding '
'or save diarization result')
# diarization params
parser.add_argument('--diar_num_spks',
type=int,
default=None,
help='number of speakers')
parser.add_argument('--diar_min_num_spks',
type=int,
default=1,
help='minimum number of speakers')
parser.add_argument('--diar_max_num_spks',
type=int,
default=20,
help='maximum number of speakers')
parser.add_argument('--diar_min_duration',
type=float,
default=0.255,
1 change: 1 addition & 0 deletions wespeaker/diar/extract_emb.py
@@ -37,6 +37,7 @@ def init_session(source, device):
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
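# Severity 0 (verbose) surfaces ONNX Runtime session/provider messages,
# e.g. to confirm that the CUDA execution provider is actually in use.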
opts.log_severity_level = 0
session = ort.InferenceSession(source,
sess_options=opts,
providers=providers)