Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C++ runtime and Python APIs for Moonshine models #1473

Merged
merged 7 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions .github/scripts/test-offline-moonshine.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet.
  # Prints "<timestamp> (<file>:<line>:<function>) <message>" to stdout, where
  # file/line/function identify the call site of log, not log itself.
  local caller_file="${BASH_SOURCE[1]##*/}"
  local caller_line="${BASH_LINENO[0]}"
  local caller_func="${FUNCNAME[1]}"
  local stamp
  # A failing `date` must not abort the script; the message is still printed.
  stamp="$(date '+%Y-%m-%d %H:%M:%S')" || true
  # -e: backslash escapes inside the message are interpreted.
  echo -e "${stamp} (${caller_file}:${caller_line}:${caller_func}) $*"
}

# NOTE(review): presumably relaxes Git's clone-time protection for any git
# command run by child processes; nothing in this script calls git directly.
export GIT_CLONE_PROTECTION_ACTIVE=false

# Show which binary is under test and where it will be looked up.
echo "EXE is $EXE"
echo "PATH: $PATH"

# Resolve the executable on PATH. `command -v` is a shell builtin (no external
# `which` needed) and the quoting keeps a value containing spaces intact.
# Under `set -e` a missing executable still aborts the script here.
command -v "$EXE"

# Moonshine model sizes to test; each maps to one release archive.
names=(
  tiny
  base
)

# For every model size in `names`: download the int8 Moonshine release
# archive, unpack it, decode the wave files under $repo/test_wavs with $EXE,
# and finally remove the unpacked model directory.
for name in "${names[@]}"; do
  log "------------------------------------------------------------"
  log "Run $name"
  log "------------------------------------------------------------"

  repo=sherpa-onnx-moonshine-${name}-en-int8
  archive=${repo}.tar.bz2
  repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/${archive}

  # -f makes curl exit non-zero on an HTTP error instead of saving the error
  # page as the archive, so a bad URL fails here rather than inside tar.
  curl -fSL -O "$repo_url"
  tar xvf "$archive"
  rm -- "$archive"
  log "Start testing ${repo_url}"

  log "test int8 onnx"

  time "$EXE" \
    --moonshine-preprocessor="$repo/preprocess.onnx" \
    --moonshine-encoder="$repo/encode.int8.onnx" \
    --moonshine-uncached-decoder="$repo/uncached_decode.int8.onnx" \
    --moonshine-cached-decoder="$repo/cached_decode.int8.onnx" \
    --tokens="$repo/tokens.txt" \
    --num-threads=2 \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  # ${repo:?} aborts instead of running `rm -rf` on an empty path.
  rm -rf -- "${repo:?}"
done
10 changes: 10 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# Smoke-test the Python Moonshine example with the tiny English int8 model.
log "test offline Moonshine"

# Download the release archive, unpack it, and delete the archive itself.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2

# NOTE(review): the example takes no arguments here, so it presumably reads the
# model from ./sherpa-onnx-moonshine-tiny-en-int8 -- confirm in the script.
python3 ./python-api-examples/offline-moonshine-decode-files.py

# Remove the model directory (presumably the one unpacked above).
rm -rf sherpa-onnx-moonshine-tiny-en-int8

log "test offline speaker diarization"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Expand Down
13 changes: 13 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*

- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline

readelf -d build/bin/sherpa-onnx-offline

.github/scripts/test-offline-moonshine.sh
du -h -d1 .

- name: Test offline CTC
shell: bash
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx

- name: Test offline Moonshine
if: matrix.build_type != 'Debug'
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline

.github/scripts/test-offline-moonshine.sh

- name: Test C++ API
shell: bash
run: |
Expand Down Expand Up @@ -243,8 +252,6 @@ jobs:

.github/scripts/test-offline-whisper.sh



- name: Test online transducer
shell: bash
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline Moonshine for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe

.github/scripts/test-offline-moonshine.sh

- name: Test C++ API
shell: bash
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*

- name: Test offline Moonshine for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe

.github/scripts/test-offline-moonshine.sh

- name: Test C++ API
shell: bash
run: |
Expand Down
117 changes: 108 additions & 9 deletions python-api-examples/generate-subtitles.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,19 @@
--feature-dim=80 \
/path/to/test.mp4

(3) For Whisper models
(3) For Moonshine models

./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
--moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
--moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
--moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
--moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
--tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \
--num-threads=2 \
/path/to/test.mp4

(4) For Whisper models

./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -58,7 +70,7 @@
--num-threads=2 \
/path/to/test.mp4

(4) For SenseVoice CTC models
(5) For SenseVoice CTC models

./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -68,7 +80,7 @@
/path/to/test.mp4


(5) For WeNet CTC models
(6) For WeNet CTC models

./python-api-examples/generate-subtitles.py \
--silero-vad-model=/path/to/silero_vad.onnx \
Expand All @@ -83,6 +95,7 @@
used in this file.
"""
import argparse
import datetime as dt
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -157,7 +170,7 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
default=1,
default=2,
help="Number of threads for neural network computation",
)

Expand Down Expand Up @@ -208,6 +221,34 @@ def get_args():
""",
)

parser.add_argument(
"--moonshine-preprocessor",
default="",
type=str,
help="Path to moonshine preprocessor model",
)

parser.add_argument(
"--moonshine-encoder",
default="",
type=str,
help="Path to moonshine encoder model",
)

parser.add_argument(
"--moonshine-uncached-decoder",
default="",
type=str,
help="Path to moonshine uncached decoder model",
)

parser.add_argument(
"--moonshine-cached-decoder",
default="",
type=str,
help="Path to moonshine cached decoder model",
)

parser.add_argument(
"--decoding-method",
type=str,
Expand Down Expand Up @@ -263,6 +304,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
Expand All @@ -284,6 +331,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.paraformer)

Expand All @@ -300,6 +353,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.sense_voice)
recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
Expand All @@ -312,6 +371,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.wenet_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

assert_file_exists(args.wenet_ctc)

Expand All @@ -327,6 +392,12 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.whisper_encoder:
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)
assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
assert (
len(args.moonshine_uncached_decoder) == 0
), args.moonshine_uncached_decoder
assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
encoder=args.whisper_encoder,
Expand All @@ -339,6 +410,22 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
task=args.whisper_task,
tail_paddings=args.whisper_tail_paddings,
)
elif args.moonshine_preprocessor:
assert_file_exists(args.moonshine_preprocessor)
assert_file_exists(args.moonshine_encoder)
assert_file_exists(args.moonshine_uncached_decoder)
assert_file_exists(args.moonshine_cached_decoder)

recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
preprocessor=args.moonshine_preprocessor,
encoder=args.moonshine_encoder,
uncached_decoder=args.moonshine_uncached_decoder,
cached_decoder=args.moonshine_cached_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
raise ValueError("Please specify at least one model")

Expand Down Expand Up @@ -424,28 +511,32 @@ def main():
segment_list = []

print("Started!")
start_t = dt.datetime.now()
num_processed_samples = 0

is_silence = False
is_eof = False
# TODO(fangjun): Support multithreads
while True:
# *2 because int16_t has two bytes
data = process.stdout.read(frames_per_read * 2)
if not data:
if is_silence:
if is_eof:
break
is_silence = True
# The converted audio file does not have a mute data of 1 second or more at the end, which will result in the loss of the last segment data
is_eof = True
# pad 1 second at the end of the file for the VAD
data = np.zeros(1 * args.sample_rate, dtype=np.int16)

samples = np.frombuffer(data, dtype=np.int16)
samples = samples.astype(np.float32) / 32768

num_processed_samples += samples.shape[0]

buffer = np.concatenate([buffer, samples])
while len(buffer) > window_size:
vad.accept_waveform(buffer[:window_size])
buffer = buffer[window_size:]

if is_silence:
if is_eof:
vad.flush()

streams = []
Expand All @@ -471,6 +562,11 @@ def main():
seg.text = stream.result.text
segment_list.append(seg)

end_t = dt.datetime.now()
elapsed_seconds = (end_t - start_t).total_seconds()
duration = num_processed_samples / 16000
rtf = elapsed_seconds / duration

srt_filename = Path(args.sound_file).with_suffix(".srt")
with open(srt_filename, "w", encoding="utf-8") as f:
for i, seg in enumerate(segment_list):
Expand All @@ -479,6 +575,9 @@ def main():
print("", file=f)

print(f"Saved to {srt_filename}")
print(f"Audio duration:\t{duration:.3f} s")
print(f"Elapsed:\t{elapsed_seconds:.3f} s")
print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print("Done!")


Expand Down
Loading
Loading