diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 506adebc7..82f9e1cf2 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \ python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose rm -rf $repo + +# test text2token +git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data + +python3 sherpa-onnx/python/tests/test_text2token.py --verbose diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml index d7ea4dde0..d48d5763d 100644 --- a/.github/workflows/test-python-offline-websocket-server.yaml +++ b/.github/workflows/test-python-offline-websocket-server.yaml @@ -39,7 +39,7 @@ jobs: - name: Install Python dependencies shell: bash run: | - python3 -m pip install --upgrade pip numpy + python3 -m pip install --upgrade pip numpy sentencepiece - name: Install sherpa-onnx shell: bash diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml index 7616afa35..15f81778e 100644 --- a/.github/workflows/test-python-online-websocket-server.yaml +++ b/.github/workflows/test-python-online-websocket-server.yaml @@ -39,7 +39,7 @@ jobs: - name: Install Python dependencies shell: bash run: | - python3 -m pip install --upgrade pip numpy + python3 -m pip install --upgrade pip numpy sentencepiece - name: Install sherpa-onnx shell: bash diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 7aef58d70..902f658c4 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): ) +def add_hotwords_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--hotwords-file", + type=str, + 
default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + def check_args(args): if not Path(args.tokens).is_file(): raise ValueError(f"{args.tokens} does not exist") @@ -342,6 +367,10 @@ def check_args(args): assert Path(args.decoder).is_file(), args.decoder assert Path(args.joiner).is_file(), args.joiner + if args.hotwords_file != "": + assert args.decoding_method == "modified_beam_search", args.decoding_method + assert Path(args.hotwords_file).is_file(), args.hotwords_file + def get_args(): parser = argparse.ArgumentParser( @@ -351,6 +380,7 @@ def get_args(): add_model_args(parser) add_feature_config_args(parser) add_decoding_args(parser) + add_hotwords_args(parser) parser.add_argument( "--port", @@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: feature_dim=args.feat_dim, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) elif args.paraformer: assert len(args.nemo_ctc) == 0, args.nemo_ctc diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index c53e1048a..ad8d1ebaf 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -82,7 +82,6 @@ from typing import List, Tuple import numpy as np -import sentencepiece as spm import sherpa_onnx @@ -98,43 +97,25 @@ def get_args(): ) parser.add_argument( - "--bpe-model", + "--hotwords-file", type=str, default="", help=""" - Path to bpe.model, - Used only when --decoding-method=modified_beam_search - """, - ) + The file containing hotwords, one 
words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: - parser.add_argument( - "--modeling-unit", - type=str, - default="char", - help=""" - The type of modeling unit. - Valid values are bpe, bpe+char, char. - Note: the char here means characters in CJK languages. + ▁HE LL O ▁WORLD + 你 好 世 界 """, ) parser.add_argument( - "--contexts", - type=str, - default="", - help=""" - The context list, it is a string containing some words/phrases separated - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". - """, - ) - - parser.add_argument( - "--context-score", + "--hotwords-score", type=float, default=1.5, help=""" - The context score of each token for biasing word/phrase. Used only if - --contexts is given. + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. """, ) @@ -273,25 +254,6 @@ def assert_file_exists(filename: str): "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" ) - -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: - sp = None - if "bpe" in args.modeling_unit: - assert_file_exists(args.bpe_model) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - tokens = {} - with open(args.tokens, "r", encoding="utf-8") as f: - for line in f: - toks = line.strip().split() - assert len(toks) == 2, len(toks) - assert toks[0] not in tokens, f"Duplicate token: {toks} " - tokens[toks[0]] = int(toks[1]) - return sherpa_onnx.encode_contexts( - modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens - ) - - def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: """ Args: @@ -322,7 +284,6 @@ def main(): assert_file_exists(args.tokens) assert args.num_threads > 0, args.num_threads - contexts_list = [] if args.encoder: assert len(args.paraformer) == 0, args.paraformer assert len(args.nemo_ctc) == 0, args.nemo_ctc @@ -330,11 +291,6 @@ def main(): assert len(args.whisper_decoder) == 0, 
args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] - if contexts: - print(f"Contexts list: {contexts}") - contexts_list = encode_contexts(args, contexts) - assert_file_exists(args.encoder) assert_file_exists(args.decoder) assert_file_exists(args.joiner) @@ -348,7 +304,8 @@ def main(): sample_rate=args.sample_rate, feature_dim=args.feature_dim, decoding_method=args.decoding_method, - context_score=args.context_score, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, debug=args.debug, ) elif args.paraformer: @@ -425,12 +382,7 @@ def main(): samples, sample_rate = read_wave(wave_filename) duration = len(samples) / sample_rate total_duration += duration - if contexts_list: - assert len(args.paraformer) == 0, args.paraformer - assert len(args.nemo_ctc) == 0, args.nemo_ctc - s = recognizer.create_stream(contexts_list=contexts_list) - else: - s = recognizer.create_stream() + s = recognizer.create_stream() s.accept_waveform(sample_rate, samples) streams.append(s) diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index eff854279..cdf7870fb 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -48,7 +48,6 @@ from typing import List, Tuple import numpy as np -import sentencepiece as spm import sherpa_onnx @@ -124,46 +123,25 @@ def get_args(): ) parser.add_argument( - "--bpe-model", + "--hotwords-file", type=str, default="", help=""" - Path to bpe.model, it will be used to tokenize contexts biasing phrases. - Used only when --decoding-method=modified_beam_search - """, - ) - - parser.add_argument( - "--modeling-unit", - type=str, - default="char", - help=""" - The type of modeling unit, it will be used to tokenize contexts biasing phrases. - Valid values are bpe, bpe+char, char. - Note: the char here means characters in CJK languages. 
- Used only when --decoding-method=modified_beam_search - """, - ) + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: - parser.add_argument( - "--contexts", - type=str, - default="", - help=""" - The context list, it is a string containing some words/phrases separated - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". - Used only when --decoding-method=modified_beam_search + ▁HE LL O ▁WORLD + 你 好 世 界 """, ) parser.add_argument( - "--context-score", + "--hotwords-score", type=float, default=1.5, help=""" - The context score of each token for biasing word/phrase. Used only if - --contexts is given. - Used only when --decoding-method=modified_beam_search + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. """, ) @@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: return samples_float32, f.getframerate() -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: - sp = None - if "bpe" in args.modeling_unit: - assert_file_exists(args.bpe_model) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - tokens = {} - with open(args.tokens, "r", encoding="utf-8") as f: - for line in f: - toks = line.strip().split() - assert len(toks) == 2, len(toks) - assert toks[0] not in tokens, f"Duplicate token: {toks} " - tokens[toks[0]] = int(toks[1]) - return sherpa_onnx.encode_contexts( - modeling_unit=args.modeling_unit, - contexts=contexts, - sp=sp, - tokens_table=tokens, - ) - - def main(): args = get_args() assert_file_exists(args.tokens) @@ -258,7 +215,8 @@ def main(): feature_dim=80, decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, - context_score=args.context_score, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) elif args.paraformer_encoder: recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( @@ -277,12 +235,6 @@ def main(): 
print("Started!") start_time = time.time() - contexts_list = [] - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] - if contexts: - print(f"Contexts list: {contexts}") - contexts_list = encode_contexts(args, contexts) - streams = [] total_duration = 0 for wave_filename in args.sound_files: @@ -291,10 +243,7 @@ def main(): duration = len(samples) / sample_rate total_duration += duration - if contexts_list: - s = recognizer.create_stream(contexts_list=contexts_list) - else: - s = recognizer.create_stream() + s = recognizer.create_stream() s.accept_waveform(sample_rate, samples) diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 4ed67602f..10ca8cdc7 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -79,6 +79,30 @@ def get_args(): help="Valid values: cpu, cuda, coreml", ) + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, + ) + + return parser.parse_args() @@ -104,6 +128,8 @@ def create_recognizer(args): rule3_min_utterance_length=300, # it essentially disables this rule decoding_method=args.decoding_method, provider=args.provider, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) return recognizer diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index 9f6be910e..a5aecb67d 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -11,7 +11,6 @@ from pathlib import Path from typing import List -import sentencepiece as spm try: import sounddevice as sd @@ -90,49 +89,29 @@ def get_args(): ) parser.add_argument( - "--bpe-model", + "--hotwords-file", type=str, default="", help=""" - Path to bpe.model, it will be used to tokenize contexts biasing phrases. - Used only when --decoding-method=modified_beam_search - """, - ) + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: - parser.add_argument( - "--modeling-unit", - type=str, - default="char", - help=""" - The type of modeling unit, it will be used to tokenize contexts biasing phrases. - Valid values are bpe, bpe+char, char. - Note: the char here means characters in CJK languages. - Used only when --decoding-method=modified_beam_search + ▁HE LL O ▁WORLD + 你 好 世 界 """, ) parser.add_argument( - "--contexts", - type=str, - default="", - help=""" - The context list, it is a string containing some words/phrases separated - with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY". - Used only when --decoding-method=modified_beam_search - """, - ) - - parser.add_argument( - "--context-score", + "--hotwords-score", type=float, default=1.5, help=""" - The context score of each token for biasing word/phrase. Used only if - --contexts is given.
- Used only when --decoding-method=modified_beam_search + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. """, ) + return parser.parse_args() @@ -155,32 +134,12 @@ def create_recognizer(args): decoding_method=args.decoding_method, max_active_paths=args.max_active_paths, provider=args.provider, - context_score=args.context_score, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) return recognizer -def encode_contexts(args, contexts: List[str]) -> List[List[int]]: - sp = None - if "bpe" in args.modeling_unit: - assert_file_exists(args.bpe_model) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - tokens = {} - with open(args.tokens, "r", encoding="utf-8") as f: - for line in f: - toks = line.strip().split() - assert len(toks) == 2, len(toks) - assert toks[0] not in tokens, f"Duplicate token: {toks} " - tokens[toks[0]] = int(toks[1]) - return sherpa_onnx.encode_contexts( - modeling_unit=args.modeling_unit, - contexts=contexts, - sp=sp, - tokens_table=tokens, - ) - - def main(): args = get_args() @@ -193,12 +152,6 @@ def main(): default_input_device_idx = sd.default.device[0] print(f'Use default device: {devices[default_input_device_idx]["name"]}') - contexts_list = [] - contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] - if contexts: - print(f"Contexts list: {contexts}") - contexts_list = encode_contexts(args, contexts) - recognizer = create_recognizer(args) print("Started! 
Please speak") @@ -207,10 +160,7 @@ def main(): sample_rate = 48000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms last_result = "" - if contexts_list: - stream = recognizer.create_stream(contexts_list=contexts_list) - else: - stream = recognizer.create_stream() + stream = recognizer.create_stream() with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: while True: samples, _ = s.read(samples_per_read) # a blocking read diff --git a/python-api-examples/speech-recognition-from-url.py b/python-api-examples/speech-recognition-from-url.py index 1c6c6a1f9..52c5a25a6 100755 --- a/python-api-examples/speech-recognition-from-url.py +++ b/python-api-examples/speech-recognition-from-url.py @@ -87,6 +87,30 @@ def get_args(): """, ) + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, + ) + + return parser.parse_args() @@ -107,6 +131,8 @@ def create_recognizer(args): rule1_min_trailing_silence=2.4, rule2_min_trailing_silence=1.2, rule3_min_utterance_length=300, # it essentially disables this rule + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, ) return recognizer diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index 33d4e5ee0..b5a37a40e 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser): add_modified_beam_search_args(parser) +def add_hotwords_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, + ) + + + def add_modified_beam_search_args(parser: argparse.ArgumentParser): parser.add_argument( "--num-active-paths", @@ -239,6 +265,7 @@ def get_args(): add_model_args(parser) add_decoding_args(parser) add_endpointing_args(parser) + add_hotwords_args(parser) parser.add_argument( "--port", @@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: feature_dim=args.feat_dim, decoding_method=args.decoding_method, max_active_paths=args.num_active_paths, + hotwords_score=args.hotwords_score, + hotwords_file=args.hotwords_file, enable_endpoint_detection=args.use_endpoint != 0, rule1_min_trailing_silence=args.rule1_min_trailing_silence, rule2_min_trailing_silence=args.rule2_min_trailing_silence, diff --git a/scripts/text2token.py b/scripts/text2token.py new file mode 100755 index 000000000..6ba3795f2 --- /dev/null +++ b/scripts/text2token.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + +""" +This script encode the texts (given line by line through `text`) to tokens and +write the results to the file given by ``output``. 
+ +Usage: +If the tokens_type is bpe: + +python3 ./text2token.py \ + --text texts.txt \ + --tokens tokens.txt \ + --tokens-type bpe \ + --bpe-model bpe.model \ + --output hotwords.txt + +If the tokens_type is cjkchar: + +python3 ./text2token.py \ + --text texts.txt \ + --tokens tokens.txt \ + --tokens-type cjkchar \ + --output hotwords.txt + +If the tokens_type is cjkchar+bpe: + +python3 ./text2token.py \ + --text texts.txt \ + --tokens tokens.txt \ + --tokens-type cjkchar+bpe \ + --bpe-model bpe.model \ + --output hotwords.txt + +""" +import argparse + +from sherpa_onnx import text2token + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + required=True, + help="Path to the input texts", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="The path to tokens.txt.", + ) + + parser.add_argument( + "--tokens-type", + type=str, + required=True, + help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="The path to bpe.model. 
Only required when tokens-type is bpe or cjkchar+bpe.", + ) + + parser.add_argument( + "--output", + type=str, + required=True, + help="Path where the encoded tokens will be written to.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + texts = [] + with open(args.text, "r", encoding="utf8") as f: + for line in f: + texts.append(line.strip()) + encoded_texts = text2token( + texts, + tokens=args.tokens, + tokens_type=args.tokens_type, + bpe_model=args.bpe_model, + ) + with open(args.output, "w", encoding="utf8") as f: + for txt in encoded_texts: + f.write(" ".join(txt) + "\n") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 1ba2e24b1..97bf3d860 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def get_package_version(): "numpy", "sentencepiece==0.1.96; python_version < '3.11'", "sentencepiece; python_version >= '3.11'", + "click>=7.1.1", ] @@ -93,6 +94,11 @@ def get_binaries_to_install(): "Programming Language :: Python", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], + entry_points={ + 'console_scripts': [ + 'sherpa-onnx-cli=sherpa_onnx.cli:cli', + ], + }, license="Apache licensed, as found in the LICENSE file", ) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b9bac58c4..8753af62a 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -72,6 +72,7 @@ set(sources text-utils.cc transpose.cc unbind.cc + utils.cc wave-reader.cc ) diff --git a/sherpa-onnx/csrc/context-graph-test.cc b/sherpa-onnx/csrc/context-graph-test.cc index 0e7e9b5cd..029fecf40 100644 --- a/sherpa-onnx/csrc/context-graph-test.cc +++ b/sherpa-onnx/csrc/context-graph-test.cc @@ -4,11 +4,14 @@ #include "sherpa-onnx/csrc/context-graph.h" +#include // NOLINT #include +#include #include #include #include "gtest/gtest.h" +#include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { @@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) { } } +TEST(ContextGraph, 
Benchmark) { + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution char_dist(0, 25); + std::uniform_int_distribution len_dist(3, 8); + for (int32_t num = 10; num <= 10000; num *= 10) { + std::vector> contexts; + for (int32_t i = 0; i < num; ++i) { + std::vector tmp; + int32_t word_len = len_dist(mt); + for (int32_t j = 0; j < word_len; ++j) { + tmp.push_back(char_dist(mt)); + } + contexts.push_back(std::move(tmp)); + } + auto start = std::chrono::high_resolution_clock::now(); + auto context_graph = ContextGraph(contexts, 1); + auto stop = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(stop - start); + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num, + duration.count()); + } +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.h b/sherpa-onnx/csrc/offline-recognizer-impl.h index 15c06efcd..b849de653 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-impl.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ #include +#include #include #if __ANDROID_API__ >= 9 @@ -32,7 +33,7 @@ class OfflineRecognizerImpl { virtual ~OfflineRecognizerImpl() = default; virtual std::unique_ptr CreateStream( - const std::vector> &context_list) const { + const std::string &hotwords) const { SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); exit(-1); } diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 9b7458b3c..3f4e2b05e 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -5,7 +5,9 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ +#include #include +#include // NOLINT #include #include #include @@ -16,6 +18,7 @@ #endif #include 
"sherpa-onnx/csrc/context-graph.h" +#include "sherpa-onnx/csrc/log.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer.h" @@ -25,6 +28,7 @@ #include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/utils.h" namespace sherpa_onnx { @@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { : config_(config), symbol_table_(config_.model_config.tokens), model_(std::make_unique(config_.model_config)) { + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } if (config_.decoding_method == "greedy_search") { decoder_ = std::make_unique(model_.get()); @@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { #endif std::unique_ptr CreateStream( - const std::vector> &context_list) const override { - // We create context_graph at this level, because we might have default - // context_graph(will be added later if needed) that belongs to the whole - // model rather than each stream. 
+ const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + if (!EncodeHotwords(is, symbol_table_, ¤t)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + auto context_graph = - std::make_shared(context_list, config_.context_score); + std::make_shared(current, config_.hotwords_score); return std::make_unique(config_.feat_config, context_graph); } std::unique_ptr CreateStream() const override { - return std::make_unique(config_.feat_config); + return std::make_unique(config_.feat_config, + hotwords_graph_); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { @@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } } + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, symbol_table_, &hotwords_)) { + SHERPA_ONNX_LOGE("Encode hotwords failed."); + exit(-1); + } + hotwords_graph_ = + std::make_shared(hotwords_, config_.hotwords_score); + } + private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; + std::vector> hotwords_; + ContextGraphPtr hotwords_graph_; std::unique_ptr model_; std::unique_ptr decoder_; std::unique_ptr lm_; diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index d01b6fb88..c42c26871 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { po->Register("max-active-paths", &max_active_paths, "Used only when decoding_method is modified_beam_search"); - po->Register("context-score", 
&context_score, + + po->Register( + "hotwords-file", &hotwords_file, + "The file containing hotwords, one words/phrases per line, and for each " + "phrase the bpe/cjkchar are separated by a space. For example: " + "▁HE LL O ▁WORLD, " + "你 好 世 界"); + + po->Register("hotwords-score", &hotwords_score, "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); } @@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const { os << "lm_config=" << lm_config.ToString() << ", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "max_active_paths=" << max_active_paths << ", "; - os << "context_score=" << context_score << ")"; + os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "hotwords_score=" << hotwords_score << ")"; return os.str(); } @@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) OfflineRecognizer::~OfflineRecognizer() = default; std::unique_ptr OfflineRecognizer::CreateStream( - const std::vector> &context_list) const { - return impl_->CreateStream(context_list); + const std::string &hotwords) const { + return impl_->CreateStream(hotwords); } std::unique_ptr OfflineRecognizer::CreateStream() const { diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 436a3028a..63c23bc24 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -31,7 +31,10 @@ struct OfflineRecognizerConfig { std::string decoding_method = "greedy_search"; int32_t max_active_paths = 4; - float context_score = 1.5; + + std::string hotwords_file; + float hotwords_score = 1.5; + // only greedy_search is implemented // TODO(fangjun): Implement modified_beam_search @@ -40,13 +43,16 @@ struct OfflineRecognizerConfig { const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, const std::string &decoding_method, - int32_t max_active_paths, float context_score) + int32_t
max_active_paths, + const std::string &hotwords_file, + float hotwords_score) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), decoding_method(decoding_method), max_active_paths(max_active_paths), - context_score(context_score) {} + hotwords_file(hotwords_file), + hotwords_score(hotwords_score) {} void Register(ParseOptions *po); bool Validate() const; @@ -69,9 +75,17 @@ class OfflineRecognizer { /// Create a stream for decoding. std::unique_ptr CreateStream() const; - /// Create a stream for decoding. + /** Create a stream for decoding. + * + * @param The hotwords for this string, it might contain several hotwords, + * the hotwords are separated by "/". In each of the hotwords, there + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" + */ std::unique_ptr CreateStream( - const std::vector> &context_list) const; + const std::string &hotwords) const; /** Decode a single stream * diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h index 515c9d9e8..db07ffa53 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -6,6 +6,7 @@ #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ #include +#include #include #include "sherpa-onnx/csrc/macros.h" @@ -29,7 +30,7 @@ class OnlineRecognizerImpl { virtual std::unique_ptr CreateStream() const = 0; virtual std::unique_ptr CreateStream( - const std::vector> &contexts) const { + const std::string &hotwords) const { SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); exit(-1); } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index e08993dc1..9af9a7800 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -7,6 +7,8 @@ 
#include #include +#include // NOLINT +#include #include #include @@ -20,6 +22,7 @@ #include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" #include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/utils.h" namespace sherpa_onnx { @@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { model_(OnlineTransducerModel::Create(config.model_config)), sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } if (sym_.contains("")) { unk_id_ = sym_[""]; } @@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { #endif std::unique_ptr CreateStream() const override { - auto stream = std::make_unique(config_.feat_config); + auto stream = + std::make_unique(config_.feat_config, hotwords_graph_); InitOnlineStream(stream.get()); return stream; } std::unique_ptr CreateStream( - const std::vector> &contexts) const override { - // We create context_graph at this level, because we might have default - // context_graph(will be added later if needed) that belongs to the whole - // model rather than each stream. 
+ const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + if (!EncodeHotwords(is, sym_, ¤t)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); auto context_graph = - std::make_shared(contexts, config_.context_score); + std::make_shared(current, config_.hotwords_score); auto stream = std::make_unique(config_.feat_config, context_graph); InitOnlineStream(stream.get()); @@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { s->Reset(); } + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, sym_, &hotwords_)) { + SHERPA_ONNX_LOGE("Encode hotwords failed."); + exit(-1); + } + hotwords_graph_ = + std::make_shared(hotwords_, config_.hotwords_score); + } + private: void InitOnlineStream(OnlineStream *stream) const { auto r = decoder_->GetEmptyResult(); @@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { private: OnlineRecognizerConfig config_; + std::vector> hotwords_; + ContextGraphPtr hotwords_graph_; std::unique_ptr model_; std::unique_ptr lm_; std::unique_ptr decoder_; diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index f72e7fc42..c3b187665 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "True to enable endpoint detection. 
False to disable it."); po->Register("max-active-paths", &max_active_paths, "beam size used in modified beam search."); - po->Register("context-score", &context_score, + po->Register("hotwords-score", &hotwords_score, "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); + po->Register( + "hotwords-file", &hotwords_file, + "The file containing hotwords, one words/phrases per line, and for each" + "phrase the bpe/cjkchar are separated by a space. For example: " + "▁HE LL O ▁WORLD" + "你 好 世 界"); po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); @@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "endpoint_config=" << endpoint_config.ToString() << ", "; os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; os << "max_active_paths=" << max_active_paths << ", "; - os << "context_score=" << context_score << ", "; + os << "hotwords_score=" << hotwords_score << ", "; + os << "hotwords_file=\"" << hotwords_file << "\", "; os << "decoding_method=\"" << decoding_method << "\")"; return os.str(); @@ -109,8 +116,8 @@ std::unique_ptr OnlineRecognizer::CreateStream() const { } std::unique_ptr OnlineRecognizer::CreateStream( - const std::vector> &context_list) const { - return impl_->CreateStream(context_list); + const std::string &hotwords) const { + return impl_->CreateStream(hotwords); } bool OnlineRecognizer::IsReady(OnlineStream *s) const { diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index cbac9d08f..3aa838026 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -78,8 +78,10 @@ struct OnlineRecognizerConfig { // used only for modified_beam_search int32_t max_active_paths = 4; + /// used only for modified_beam_search - float context_score = 1.5; + float hotwords_score = 1.5; + std::string hotwords_file; 
OnlineRecognizerConfig() = default; @@ -89,14 +91,16 @@ struct OnlineRecognizerConfig { const EndpointConfig &endpoint_config, bool enable_endpoint, const std::string &decoding_method, - int32_t max_active_paths, float context_score) + int32_t max_active_paths, + const std::string &hotwords_file, float hotwords_score) : feat_config(feat_config), model_config(model_config), endpoint_config(endpoint_config), enable_endpoint(enable_endpoint), decoding_method(decoding_method), max_active_paths(max_active_paths), - context_score(context_score) {} + hotwords_score(hotwords_score), + hotwords_file(hotwords_file) {} void Register(ParseOptions *po); bool Validate() const; @@ -119,9 +123,16 @@ class OnlineRecognizer { /// Create a stream for decoding. std::unique_ptr CreateStream() const; - // Create a stream with context phrases - std::unique_ptr CreateStream( - const std::vector> &context_list) const; + /** Create a stream for decoding. + * + * @param The hotwords for this string, it might contain several hotwords, + * the hotwords are separated by "/". In each of the hotwords, there + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" + */ + std::unique_ptr CreateStream(const std::string &hotwords) const; /** * Return true if the given stream has enough frames for decoding. 
diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc new file mode 100644 index 000000000..a437abe2a --- /dev/null +++ b/sherpa-onnx/csrc/utils.cc @@ -0,0 +1,54 @@ +// sherpa-onnx/csrc/utils.cc +// +// Copyright 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/utils.h" + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/log.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, + std::vector> *hotwords) { + hotwords->clear(); + std::vector tmp; + std::string line; + std::string word; + + while (std::getline(is, line)) { + std::istringstream iss(line); + std::vector syms; + while (iss >> word) { + if (word.size() >= 3) { + // For BPE-based models, we replace ▁ with a space + // Unicode 9601, hex 0x2581, utf8 0xe29681 + const uint8_t *p = reinterpret_cast(word.c_str()); + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + word = word.replace(0, 3, " "); + } + } + if (symbol_table.contains(word)) { + int32_t number = symbol_table[word]; + tmp.push_back(number); + } else { + SHERPA_ONNX_LOGE( + "Cannot find ID for hotword %s at line: %s. (Hint: words on " + "the " + "same line are separated by spaces)", + word.c_str(), line.c_str()); + return false; + } + } + hotwords->push_back(std::move(tmp)); + } + return true; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h new file mode 100644 index 000000000..19d75d873 --- /dev/null +++ b/sherpa-onnx/csrc/utils.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/utils.h +// +// Copyright 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_UTILS_H_ +#define SHERPA_ONNX_CSRC_UTILS_H_ + +#include +#include + +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +/* Encode the hotwords in an input stream to be tokens ids. + * + * @param is The input stream, it contains several lines, one hotword for each + * line. 
For each hotword, the tokens (cjkchar or bpe) are separated + * by spaces. + * @param symbol_table The tokens table mapping symbols to ids. All the symbols + * in the stream should be in the symbol_table, if not this + * function returns fasle. + * + * @@param hotwords The encoded ids to be written to. + * + * @return If all the symbols from ``is`` are in the symbol_table, returns true + * otherwise returns false. + */ +bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, + std::vector> *hotwords); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_UTILS_H_ diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 462d8ba38..dbeec96ce 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) { py::class_(*m, "OfflineRecognizerConfig") .def(py::init(), + const std::string &, int32_t, const std::string &, float>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("decoding_method") = "greedy_search", - py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5) + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 1.5) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("max_active_paths", &PyClass::max_active_paths) - .def_readwrite("context_score", &PyClass::context_score) + .def_readwrite("hotwords_file", &PyClass::hotwords_file) + .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def("__str__", &PyClass::ToString); } @@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) { [](const PyClass &self) { return self.CreateStream(); }) .def( "create_stream", - [](PyClass 
&self, - const std::vector> &contexts_list) { - return self.CreateStream(contexts_list); + [](PyClass &self, const std::string &hotwords) { + return self.CreateStream(hotwords); }, - py::arg("contexts_list")) + py::arg("hotwords")) .def("decode_stream", &PyClass::DecodeStream) .def("decode_streams", [](const PyClass &self, std::vector ss) { diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 7e37a87c8..8699e56d0 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) { using PyClass = OnlineModelConfig; py::class_(*m, "OnlineModelConfig") .def(py::init(), + const OnlineParaformerModelConfig &, const std::string &, + int32_t, bool, const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index c130d87c9..68e97b60a 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::class_(*m, "OnlineRecognizerConfig") .def(py::init(), + const std::string &, int32_t, const std::string &, float>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("context_score") = 0) + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) 
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint) .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("max_active_paths", &PyClass::max_active_paths) - .def_readwrite("context_score", &PyClass::context_score) + .def_readwrite("hotwords_file", &PyClass::hotwords_file) + .def_readwrite("hotwords_score", &PyClass::hotwords_score) .def("__str__", &PyClass::ToString); } @@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) { [](const PyClass &self) { return self.CreateStream(); }) .def( "create_stream", - [](PyClass &self, - const std::vector> &contexts_list) { - return self.CreateStream(contexts_list); + [](PyClass &self, const std::string &hotwords) { + return self.CreateStream(hotwords); }, - py::arg("contexts_list")) + py::arg("hotwords")) .def("is_ready", &PyClass::IsReady) .def("decode_stream", &PyClass::DecodeStream) .def("decode_streams", diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 0f1e23f52..b21156c7b 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -4,4 +4,4 @@ from .offline_recognizer import OfflineRecognizer from .online_recognizer import OnlineRecognizer -from .utils import encode_contexts +from .utils import text2token diff --git a/sherpa-onnx/python/sherpa_onnx/cli.py b/sherpa-onnx/python/sherpa_onnx/cli.py new file mode 100644 index 000000000..971e724bb --- /dev/null +++ b/sherpa-onnx/python/sherpa_onnx/cli.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Xiaomi Corporation + +import logging +import click +from pathlib import Path +from sherpa_onnx import text2token + + +@click.group() +def cli(): + """ + The shell entry point to sherpa-onnx. 
+ """ + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + +@cli.command(name="text2token") +@click.argument("input", type=click.Path(exists=True, dir_okay=False)) +@click.argument("output", type=click.Path()) +@click.option( + "--tokens", + type=str, + required=True, + help="The path to tokens.txt.", +) +@click.option( + "--tokens-type", + type=str, + required=True, + help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe", +) +@click.option( + "--bpe-model", + type=str, + help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.", +) +def encode_text( + input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path +): + """ + Encode the texts given by the INPUT to tokens and write the results to the OUTPUT. + """ + texts = [] + with open(input, "r", encoding="utf8") as f: + for line in f: + texts.append(line.strip()) + encoded_texts = text2token( + texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model + ) + with open(output, "w", encoding="utf8") as f: + for txt in encoded_texts: + f.write(" ".join(txt) + "\n") diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 26ee9b27f..6b737be95 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -43,7 +43,8 @@ def from_transducer( feature_dim: int = 80, decoding_method: str = "greedy_search", max_active_paths: int = 4, - context_score: float = 1.5, + hotwords_file: str = "", + hotwords_score: float = 1.5, debug: bool = False, provider: str = "cpu", ): @@ -105,7 +106,8 @@ def from_transducer( feat_config=feat_config, model_config=model_config, decoding_method=decoding_method, - context_score=context_score, + hotwords_file=hotwords_file, + hotwords_score=hotwords_score, ) self.recognizer = _Recognizer(recognizer_config) self.config = 
recognizer_config @@ -379,11 +381,11 @@ def from_tdnn_ctc( self.config = recognizer_config return self - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): - if contexts_list is None: + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: return self.recognizer.create_stream() else: - return self.recognizer.create_stream(contexts_list) + return self.recognizer.create_stream(hotwords) def decode_stream(self, s: OfflineStream): self.recognizer.decode_stream(s) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 55e789ba0..e4f991a04 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -42,7 +42,8 @@ def from_transducer( rule3_min_utterance_length: float = 20.0, decoding_method: str = "greedy_search", max_active_paths: int = 4, - context_score: float = 1.5, + hotwords_score: float = 1.5, + hotwords_file: str = "", provider: str = "cpu", model_type: str = "", ): @@ -138,7 +139,8 @@ def from_transducer( enable_endpoint=enable_endpoint_detection, decoding_method=decoding_method, max_active_paths=max_active_paths, - context_score=context_score, + hotwords_score=hotwords_score, + hotwords_file=hotwords_file, ) self.recognizer = _Recognizer(recognizer_config) @@ -248,11 +250,11 @@ def from_paraformer( self.config = recognizer_config return self - def create_stream(self, contexts_list: Optional[List[List[int]]] = None): - if contexts_list is None: + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: return self.recognizer.create_stream() else: - return self.recognizer.create_stream(contexts_list) + return self.recognizer.create_stream(hotwords) def decode_stream(self, s: OnlineStream): self.recognizer.decode_stream(s) diff --git a/sherpa-onnx/python/sherpa_onnx/utils.py b/sherpa-onnx/python/sherpa_onnx/utils.py index dbe6d91e4..a02a8e4c0 100644 --- 
a/sherpa-onnx/python/sherpa_onnx/utils.py +++ b/sherpa-onnx/python/sherpa_onnx/utils.py @@ -1,74 +1,95 @@ -from typing import Dict, List, Optional +# Copyright (c) 2023 Xiaomi Corporation +import re +from pathlib import Path +from typing import List, Optional, Union -def encode_contexts( - modeling_unit: str, - contexts: List[str], - sp: Optional["SentencePieceProcessor"] = None, - tokens_table: Optional[Dict[str, int]] = None, -) -> List[List[int]]: +import sentencepiece as spm + + +def text2token( + texts: List[str], + tokens: str, + tokens_type: str = "cjkchar", + bpe_model: Optional[str] = None, + output_ids: bool = False, +) -> List[List[Union[str, int]]]: """ - Encode the given contexts (a list of string) to a list of a list of token ids. + Encode the given texts (a list of string) to a list of a list of tokens. Args: - modeling_unit: - The valid values are bpe, char, bpe+char. - Note: char here means characters in CJK languages, not English like languages. - contexts: + texts: The given contexts list (a list of string). - sp: - An instance of SentencePieceProcessor. - tokens_table: - The tokens_table containing the tokens and the corresponding ids. + tokens: + The path of the tokens.txt. + tokens_type: + The valid values are cjkchar, bpe, cjkchar+bpe. + bpe_model: + The path of the bpe model. Only required when tokens_type is bpe or + cjkchar+bpe. + output_ids: + True to output token ids otherwise tokens. Returns: - Return the contexts_list, it is a list of a list of token ids. + Return the encoded texts, it is a list of a list of token ids if output_ids + is True, or it is a list of list of tokens. 
""" - contexts_list = [] - if "bpe" in modeling_unit: - assert sp is not None - if "char" in modeling_unit: - assert tokens_table is not None - assert len(tokens_table) > 0, len(tokens_table) + assert Path(tokens).is_file(), f"File not exists, {tokens}" + tokens_table = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, len(toks) + assert toks[0] not in tokens_table, f"Duplicate token: {toks} " + tokens_table[toks[0]] = int(toks[1]) - if "char" == modeling_unit: - for context in contexts: - assert ' ' not in context - ids = [ - tokens_table[txt] if txt in tokens_table else tokens_table[""] - for txt in context - ] - contexts_list.append(ids) - elif "bpe" == modeling_unit: - contexts_list = sp.encode(contexts, out_type=int) - else: - assert modeling_unit == "bpe+char", modeling_unit + if "bpe" in tokens_type: + assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}" + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + texts_list: List[List[str]] = [] + + if tokens_type == "cjkchar": + texts_list = [list("".join(text.split())) for text in texts] + elif tokens_type == "bpe": + texts_list = sp.encode(texts, out_type=str) + else: + assert ( + tokens_type == "cjkchar+bpe" + ), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}" # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) pattern = re.compile(r"([\u4e00-\u9fff])") - for context in contexts: + for text in texts: # Example: # txt = "你好 ITS'S OKAY 的" # chars = ["你", "好", " ITS'S OKAY ", "的"] - chars = pattern.split(context.upper()) + chars = pattern.split(text) mix_chars = [w for w in chars if len(w.strip()) > 0] - ids = [] + text_list = [] for ch_or_w in mix_chars: # ch_or_w is a single CJK charater(i.e., "你"), do nothing. 
if pattern.fullmatch(ch_or_w) is not None: - ids.append( - tokens_table[ch_or_w] - if ch_or_w in tokens_table - else tokens_table[""] - ) + text_list.append(ch_or_w) # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), # encode ch_or_w using bpe_model. else: - for p in sp.encode_as_pieces(ch_or_w): - ids.append( - tokens_table[p] - if p in tokens_table - else tokens_table[""] - ) - contexts_list.append(ids) - return contexts_list + text_list += sp.encode_as_pieces(ch_or_w) + texts_list.append(text_list) + + result: List[List[Union[int, str]]] = [] + for text in texts_list: + text_list = [] + contain_oov = False + for txt in text: + if txt in tokens_table: + text_list.append(tokens_table[txt] if output_ids else txt) + else: + print(f"OOV token : {txt}, skipping text : {text}.") + contain_oov = True + break + if contain_oov: + continue + else: + result.append(text_list) + return result diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt index ff9b8c9eb..4fd285293 100644 --- a/sherpa-onnx/python/tests/CMakeLists.txt +++ b/sherpa-onnx/python/tests/CMakeLists.txt @@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source) COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/${source}" + WORKING_DIRECTORY + ${CMAKE_CURRENT_SOURCE_DIR} ) get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) set_property(TEST ${name} - PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$:$ENV{PYTHONPATH}" + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$:$ENV{PYTHONPATH}" ) endfunction() @@ -21,6 +23,7 @@ set(py_test_files test_offline_recognizer.py test_online_recognizer.py test_online_transducer_model_config.py + test_text2token.py ) foreach(source IN LISTS py_test_files) diff --git a/sherpa-onnx/python/tests/test_text2token.py b/sherpa-onnx/python/tests/test_text2token.py new file mode 100644 index 000000000..7bc065b6a --- /dev/null +++ b/sherpa-onnx/python/tests/test_text2token.py @@ -0,0 +1,121 @@ +# 
sherpa-onnx/python/tests/test_text2token.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_text2token_py + +import unittest +from pathlib import Path + +import sherpa_onnx + +d = "/tmp/sherpa-test-data" +# Please refer to +# https://github.com/pkufool/sherpa-test-data +# to download test data for testing + + +class TestText2Token(unittest.TestCase): + def test_bpe(self): + tokens = f"{d}/text2token/tokens_en.txt" + bpe_model = f"{d}/text2token/bpe_en.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_bpe().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["HELLO WORLD", "I LOVE YOU"] + encoded_texts = sherpa_onnx.text2token( + texts, + tokens=tokens, + tokens_type="bpe", + bpe_model=bpe_model, + ) + assert encoded_texts == [ + ["▁HE", "LL", "O", "▁WORLD"], + ["▁I", "▁LOVE", "▁YOU"], + ], encoded_texts + + encoded_ids = sherpa_onnx.text2token( + texts, + tokens=tokens, + tokens_type="bpe", + bpe_model=bpe_model, + output_ids=True, + ) + assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids + + def test_cjkchar(self): + tokens = f"{d}/text2token/tokens_cn.txt" + + if not Path(tokens).is_file(): + print( + f"No test data found, skipping test_cjkchar().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["世界人民大团结", "中国 VS 美国"] + encoded_texts = sherpa_onnx.text2token( + texts, tokens=tokens, tokens_type="cjkchar" + ) + assert encoded_texts == [ + ["世", "界", "人", "民", "大", "团", "结"], + ["中", "国", "V", "S", "美", "国"], + ], encoded_texts + encoded_ids = sherpa_onnx.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar", + output_ids=True, + ) + assert encoded_ids == [ + [379, 380, 72, 874, 93, 1251, 489], + 
[262, 147, 3423, 2476, 21, 147], + ], encoded_ids + + def test_cjkchar_bpe(self): + tokens = f"{d}/text2token/tokens_mix.txt" + bpe_model = f"{d}/text2token/bpe_mix.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_cjkchar_bpe().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"] + encoded_texts = sherpa_onnx.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar+bpe", + bpe_model=bpe_model, + ) + assert encoded_texts == [ + ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"], + ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"], + ], encoded_texts + encoded_ids = sherpa_onnx.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar+bpe", + bpe_model=bpe_model, + output_ids=True, + ) + assert encoded_ids == [ + [1368, 1392, 557, 680, 275, 178, 475], + [685, 736, 275, 178, 179, 921, 736], + ], encoded_ids + + +if __name__ == "__main__": + unittest.main()