From 68d7a921dd8f6d55994851f931287dfe6e4f4571 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 20 Dec 2024 18:15:11 +0800 Subject: [PATCH] Support decoding with byte-level BPE models. --- scripts/bbpe/.gitignore | 1 + scripts/bbpe/generate_bbpe_table.py | 66 +++++++++++++ sherpa-onnx/csrc/CMakeLists.txt | 3 +- sherpa-onnx/csrc/bbpe.cc | 60 ++++++++++++ sherpa-onnx/csrc/bbpe.h | 15 +++ .../csrc/offline-recognizer-ctc-impl.h | 7 +- .../csrc/offline-recognizer-transducer-impl.h | 6 +- sherpa-onnx/csrc/online-recognizer-ctc-impl.h | 11 ++- .../csrc/online-recognizer-transducer-impl.h | 11 ++- sherpa-onnx/csrc/symbol-table.cc | 92 ++++++++++++++++++- sherpa-onnx/csrc/symbol-table.h | 5 + 11 files changed, 267 insertions(+), 10 deletions(-) create mode 100644 scripts/bbpe/.gitignore create mode 100755 scripts/bbpe/generate_bbpe_table.py create mode 100644 sherpa-onnx/csrc/bbpe.cc create mode 100644 sherpa-onnx/csrc/bbpe.h diff --git a/scripts/bbpe/.gitignore b/scripts/bbpe/.gitignore new file mode 100644 index 000000000..fa966a067 --- /dev/null +++ b/scripts/bbpe/.gitignore @@ -0,0 +1 @@ +bbpe.cc diff --git a/scripts/bbpe/generate_bbpe_table.py b/scripts/bbpe/generate_bbpe_table.py new file mode 100755 index 000000000..c0416b27a --- /dev/null +++ b/scripts/bbpe/generate_bbpe_table.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28 +# and +# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py +# +# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall + +import re + +BPE_UNK = chr(8263) +PRINTABLE_BASE_CHARS = ( + list(range(256, 287 + 1)) + + list(range(32, 126 + 1)) + + list(range(288, 305 + 1)) + + list(range(308, 318 + 1)) + + list(range(321, 328 + 1)) + + list(range(330, 382 + 1)) + + list(range(384, 422 + 1)) +) + + +BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} +BCHAR_TO_BYTE[BPE_UNK] = 32 # map unk to space + + +def main(): + s = "" + s += "// sherpa-onnx/csrc/bbpe.cc\n" + s += "//\n" + s += "// Copyright (c) 2024 Xiaomi Corporation\n" + s += "\n" + s += "// Auto-generated! DO NOT EDIT\n" + s += "\n" + s += '#include "sherpa-onnx/csrc/bbpe.h"\n' + s += "\n" + s += "#include \n" + s += "#include \n" + s += "\n" + s += "const std::unordered_map &GetByteBpeTable() {\n" + s += " static const std::unordered_map table = {\n" + + s += " " + for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()): + s += "{" + if k in ["\\", '"']: + s += f'"\{k}", {v}' + else: + s += f'"{k}", {v}' + s += "}, " + if i > 0 and i % 7 == 0: + s += "\n" + s += " " + s += "};\n" + s += "\n" + s += " return table\n;" + s += "}\n" + + with open("bbpe.cc", "w", encoding="utf-8") as f: + f.write(s) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b3c1617e3..6bfcd2a98 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -12,6 +12,7 @@ endif() set(sources base64-decode.cc + bbpe.cc cat.cc circular-buffer.cc context-graph.cc @@ -78,11 +79,11 @@ set(sources online-stream.cc online-transducer-decoder.cc online-transducer-greedy-search-decoder.cc + online-transducer-greedy-search-nemo-decoder.cc online-transducer-model-config.cc online-transducer-model.cc online-transducer-modified-beam-search-decoder.cc online-transducer-nemo-model.cc - online-transducer-greedy-search-nemo-decoder.cc online-wenet-ctc-model-config.cc online-wenet-ctc-model.cc online-zipformer-transducer-model.cc diff --git a/sherpa-onnx/csrc/bbpe.cc b/sherpa-onnx/csrc/bbpe.cc new file mode 100644 index 000000000..d4696132e --- /dev/null +++ b/sherpa-onnx/csrc/bbpe.cc @@ -0,0 +1,60 @@ +// sherpa-onnx/csrc/bbpe.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +// Auto-generated! DO NOT EDIT + +#include "sherpa-onnx/csrc/bbpe.h" + +#include +#include + +const std::unordered_map &GetByteBpeTable() { + static const std::unordered_map table = { + {"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5}, + {"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11}, + {"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17}, + {"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23}, + {"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29}, + {"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35}, + {"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41}, + {"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47}, + {"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53}, + {"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59}, + {"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65}, + {"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71}, + {"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77}, + {"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83}, + {"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89}, + {"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95}, + {"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101}, + {"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107}, + {"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113}, + {"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119}, + {"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125}, + {"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131}, + {"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137}, + {"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143}, + {"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149}, + {"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155}, + {"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161}, + {"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167}, + {"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173}, + {"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179}, + {"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185}, + {"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191}, + {"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197}, + {"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203}, + {"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209}, + {"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215}, + {"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221}, + {"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227}, + {"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233}, + {"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239}, + {"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245}, + {"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251}, + {"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"⁇", 32}, + }; + + return table; +} diff --git a/sherpa-onnx/csrc/bbpe.h b/sherpa-onnx/csrc/bbpe.h new file mode 100644 index 000000000..8c317f83d --- /dev/null +++ b/sherpa-onnx/csrc/bbpe.h @@ -0,0 +1,15 @@ +// sherpa-onnx/csrc/bbpe.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_BBPE_H_ +#define SHERPA_ONNX_CSRC_BBPE_H_ +#include +#include + +// It is equivalent to the map BCHAR_TO_BYTE +// from +// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280 +const std::unordered_map &GetByteBpeTable(); + +#endif // SHERPA_ONNX_CSRC_BBPE_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 2721ecdf3..3dca0dfcc 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, text.append(sym); if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { - // for byte bpe models + // for bpe models with byte_fallback // (but don't rewrite printable characters 0x20..0x7e, // which collide with standard BPE units) std::ostringstream os; @@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, r.tokens.push_back(std::move(sym)); } + + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + r.text = std::move(text); float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 64f3798fa..158be5622 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert( text.append(sym); if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { - // for byte bpe models, + // for bpe models with byte_fallback, // (but don't rewrite printable characters 0x20..0x7e, // which collide with standard BPE units) std::ostringstream os; @@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert( r.tokens.push_back(std::move(sym)); } + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + r.text = std::move(text); float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 3560b1ab7..797d90f0c 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); + std::string text; for (auto i : src.tokens) { auto sym = sym_table[i]; - r.text.append(sym); + text.append(sym); if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { - // for byte bpe models + // for bpe models with byte_fallback // (but don't rewrite printable characters 0x20..0x7e, // which collide with standard BPE units) std::ostringstream os; @@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, r.tokens.push_back(std::move(sym)); } + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; for (auto t : src.timestamps) { float time = frame_shift_s * t; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 2eac3cf84..e6fd0b505 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); + std::string text; for (auto i : src.tokens) { auto sym = sym_table[i]; - r.text.append(sym); + text.append(sym); if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { - // for byte bpe models + // for bpe models with byte_fallback // (but don't rewrite printable characters 0x20..0x7e, // which collide with standard BPE units) std::ostringstream os; @@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.tokens.push_back(std::move(sym)); } + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; for (auto t : src.timestamps) { float time = frame_shift_s * t; diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 3455d2007..e36bd5e58 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/symbol-table.h" #include +#include #include #include #include @@ -22,8 +23,10 @@ #endif #include "sherpa-onnx/csrc/base64-decode.h" +#include "sherpa-onnx/csrc/bbpe.h" #include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -47,6 +50,59 @@ inline void Trim(std::string *s, const char *t = ws) { TrimRight(s, t); TrimLeft(s, t); } + +bool IsByteBPE(const char *s, int32_t n) { + const uint8_t *p = reinterpret_cast(s); + if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + return IsByteBPE(s + 3, n - 3); + } + + for (int32_t i = 0; i != n; ++i) { + if (p[i] > 0xc6) { + return false; + } + } + + return true; +} + +bool IsByteBPE(const std::unordered_map &sym2id) { + uint8_t max_v = 0; + for (const auto &p : sym2id) { + const auto &s = p.first; + if (!IsByteBPE(s.c_str(), s.size())) { + return false; + } + + uint8_t m = 0; + if (s.size() >= 3) { + const uint8_t *p = reinterpret_cast(s.c_str()); + + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + if (s.size() > 3) { + m = *std::max_element( + reinterpret_cast(s.data()) + 3, + reinterpret_cast(s.data()) + s.size()); + } else { + m = 0; + } + } else { + m = *std::max_element( + reinterpret_cast(s.data()), + reinterpret_cast(s.data()) + s.size()); + } + } else { + m = *std::max_element( + reinterpret_cast(s.data()), + reinterpret_cast(s.data()) + s.size()); + } + + max_v = (m > max_v) ? m : max_v; + } + + return static_cast(max_v) == 0xc6; +} + } // namespace std::unordered_map ReadTokens( @@ -111,7 +167,10 @@ SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { Init(is); } -void SymbolTable::Init(std::istream &is) { sym2id_ = ReadTokens(is, &id2sym_); } +void SymbolTable::Init(std::istream &is) { + sym2id_ = ReadTokens(is, &id2sym_); + is_bbpe_ = IsByteBPE(sym2id_); +} std::string SymbolTable::ToString() const { std::ostringstream os; @@ -124,7 +183,7 @@ std::string SymbolTable::ToString() const { const std::string SymbolTable::operator[](int32_t id) const { std::string sym = id2sym_.at(id); - if (sym.size() >= 3) { + if (sym.size() >= 3 && !is_bbpe_) { // For BPE-based models, we replace ▁ with a space // Unicode 9601, hex 0x2581, utf8 0xe29681 const uint8_t *p = reinterpret_cast(sym.c_str()); @@ -133,7 +192,7 @@ const std::string SymbolTable::operator[](int32_t id) const { } } - // for byte-level BPE + // for BPE with byte_fallback // id 0 is blank, id 1 is sos/eos, id 2 is unk // // Note: For moonshine models, 0 is , 1, is , 2 is @@ -172,6 +231,33 @@ void SymbolTable::ApplyBase64Decode() { } } +std::string SymbolTable::DecodeByteBpe(const std::string &text) const { + if (!is_bbpe_) { + return text; + } + auto v = SplitUtf8(text); + + const auto &bbpe_table = GetByteBpeTable(); + std::string ans; + for (const auto &s : v) { + if (s == "▁") { + if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) { + ans.push_back(' '); + } + } else if (bbpe_table.count(s)) { + ans.push_back(bbpe_table.at(s)); + } else if (std::isprint(s[0])) { + ans.append(s); + } else { + // Should not happen + SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str()); + } + } + + // TODO(fangjun): Filter invalid utf-8 sequences + return ans; +} + #if __ANDROID_API__ >= 9 template SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename); diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 20d8d206b..5950d6008 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -56,12 +56,17 @@ class SymbolTable { int32_t NumSymbols() const { return id2sym_.size(); } + std::string DecodeByteBpe(const std::string &text) const; + + bool IsByteBpe() const { return is_bbpe_; } + private: void Init(std::istream &is); private: std::unordered_map sym2id_; std::unordered_map id2sym_; + bool is_bbpe_ = false; }; std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);