diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 4ed2cb119..232412338 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -74,6 +74,8 @@ set(sources
   online-transducer-model-config.cc
   online-transducer-model.cc
   online-transducer-modified-beam-search-decoder.cc
+  online-transducer-nemo-model.cc
+  online-transducer-greedy-search-nemo-decoder.cc
   online-wenet-ctc-model-config.cc
   online-wenet-ctc-model.cc
   online-zipformer-transducer-model.cc
diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc
index 56da814f7..a6f398016 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -7,13 +7,28 @@
 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
 
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(),
+                                               decoder_model.size(),
+                                               Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    } else {
+      SHERPA_ONNX_LOGE("Running streaming NeMo transducer model");
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
@@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     AAssetManager *mgr, const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(),
+                                               decoder_model.size(),
+                                               Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    } else {
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
+    }
   }
 
   if (!config.model_config.paraformer.encoder.empty()) {
diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
index e7c8fa7e5..8fa12d94f 100644
--- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
+++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
   r.timestamps.reserve(src.tokens.size());
 
   for (auto i : src.tokens) {
+    if (i == -1) continue;
     auto sym = sym_table[i];
 
     r.text.append(sym);
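Why the output-count probe above works: a stateless zipformer-style decoder graph exports exactly one output (decoder_out), while the stateful NeMo decoder additionally exports its next LSTM states, so GetOutputCount() > 1 identifies it. A minimal standalone sketch of the same probe, assuming only the ONNX Runtime C++ API (the model path is a placeholder):

#include "onnxruntime_cxx_api.h"

// Returns true if the decoder graph looks like a stateful NeMo decoder,
// i.e., it has more than one output node (decoder_out + next states).
bool LooksLikeNeMoDecoder(const char *decoder_path /* e.g., "decoder.onnx" */) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::Session sess(env, decoder_path, Ort::SessionOptions{});
  return sess.GetOutputCount() > 1;
}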
diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
new file mode 100644
index 000000000..9193b292f
--- /dev/null
+++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
@@ -0,0 +1,267 @@
+// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
+//
+// Copyright (c)  2022-2024  Xiaomi Corporation
+// Copyright (c)  2024  Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+
+#include <algorithm>
+#include <fstream>
+#include <memory>
+#include <regex>  // NOLINT
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-recognizer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+// defined in ./online-recognizer-transducer-impl.h
+// TODO: decide whether this forward declaration should remain static
+static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
+                                      const SymbolTable &sym_table,
+                                      float frame_shift_ms,
+                                      int32_t subsampling_factor,
+                                      int32_t segment,
+                                      int32_t frames_since_start);
+
+class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
+ public:
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+    PostInit();
+  }
+
+#if __ANDROID_API__ >= 9
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      AAssetManager *mgr, const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(mgr, config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            mgr, config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+
+    PostInit();
+  }
+#endif
+
+  std::unique_ptr<OnlineStream> CreateStream() const override {
+    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
+    stream->SetStates(model_->GetInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  bool IsReady(OnlineStream *s) const override {
+    return s->GetNumProcessedFrames() + model_->ChunkSize() <
+           s->NumFramesReady();
+  }
+
+  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
+    OnlineTransducerDecoderResult decoder_result = s->GetResult();
+    decoder_->StripLeadingBlanks(&decoder_result);
+
+    // TODO(fangjun): Remember to change these constants if needed
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = 8;
+    return Convert(decoder_result, symbol_table_, frame_shift_ms,
+                   subsampling_factor, s->GetCurrentSegment(),
+                   s->GetNumFramesSinceStart());
+  }
+
+  bool IsEndpoint(OnlineStream *s) const override {
+    if (!config_.enable_endpoint) {
+      return false;
+    }
+
+    int32_t num_processed_frames = s->GetNumProcessedFrames();
+
+    // frame shift is 10 milliseconds
+    float frame_shift_in_seconds = 0.01;
+
+    // subsampling factor is 8
+    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
+
+    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
+                                frame_shift_in_seconds);
+  }
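+  // Worked example for the conversion above: with 10 ms feature frames and a
+  // subsampling factor of 8, one encoder output frame covers 80 ms, so e.g.
+  // 25 trailing blanks -> 25 * 8 = 200 feature frames -> 200 * 0.01 s = 2 s
+  // of trailing silence, which is what gets compared against the endpoint
+  // rules' min_trailing_silence thresholds.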
+
+  void Reset(OnlineStream *s) const override {
+    {
+      // segment is incremented only when the last
+      // result is not empty
+      const auto &r = s->GetResult();
+      if (!r.tokens.empty() && r.tokens.back() != 0) {
+        s->GetCurrentSegment() += 1;
+      }
+    }
+
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
+
+    auto r = decoder_->GetEmptyResult();
+
+    s->SetResult(r);
+    s->GetResult().decoder_out = std::move(decoder_out);
+
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
+    s->Reset();
+  }
+
+  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+
+    int32_t feature_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineTransducerDecoderResult> result(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> encoder_states(n);
+
+    for (int32_t i = 0; i != n; ++i) {
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
+      std::vector<float> features =
+          ss[i]->GetFrames(num_processed_frames, chunk_size);
+
+      // Question: should num_processed_frames include chunk_shift?
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      result[i] = std::move(ss[i]->GetResult());
+      encoder_states[i] = std::move(ss[i]->GetStates());
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(),
+                                            x_shape.data(), x_shape.size());
+
+    // Batch size is 1
+    auto states = std::move(encoder_states[0]);
+    int32_t num_states = static_cast<int32_t>(states.size());  // num_states = 3
+    auto t = model_->RunEncoder(std::move(x), std::move(states));
+    // t[0]: encoder_out, float tensor, (batch_size, dim, T)
+    // t[1:]: next states
+
+    std::vector<Ort::Value> out_states;
+    out_states.reserve(num_states);
+
+    for (int32_t k = 1; k != num_states + 1; ++k) {
+      out_states.push_back(std::move(t[k]));
+    }
+
+    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
+
+    // Get the per-stream decoder states
+    // (see online-transducer-greedy-search-nemo-decoder.h).
+    std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
+
+    // The decoder states for each chunk are updated inside Decode(); the
+    // value returned here is the state after the LAST chunk, which we store
+    // back on the stream for the next call.
+    decoder_states = decoder_->Decode(std::move(encoder_out),
+                                      std::move(decoder_states),
+                                      &result, ss, n);
+
+    ss[0]->SetResult(result[0]);
+
+    ss[0]->SetStates(std::move(out_states));
+  }
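+  // Chunk layout used by DecodeStreams() above (values are examples; the
+  // real ones come from the encoder metadata): with chunk_size = 8 and
+  // chunk_shift = 6, call i reads frames [i * chunk_shift,
+  // i * chunk_shift + chunk_size), so consecutive windows overlap by
+  // chunk_size - chunk_shift = 2 frames of right context:
+  //
+  //   call 0: frames 0..7
+  //   call 1: frames 6..13
+  //   call 2: frames 12..19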
+
+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    stream->SetResult(r);
+    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
+  }
+
+ private:
+  void PostInit() {
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    config_.feat_config.low_freq = 0;
+    // config_.feat_config.high_freq = 8000;
+    config_.feat_config.is_librosa = true;
+    config_.feat_config.remove_dc_offset = false;
+    // config_.feat_config.window_type = "hann";
+    config_.feat_config.dither = 0;
+
+    int32_t vocab_size = model_->VocabSize();
+
+    // check the blank ID
+    if (!symbol_table_.Contains("<blk>")) {
+      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
+      exit(-1);
+    }
+
+    if (symbol_table_["<blk>"] != vocab_size - 1) {
+      SHERPA_ONNX_LOGE("<blk> is not the last token!");
+      exit(-1);
+    }
+
+    if (symbol_table_.NumSymbols() != vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), vocab_size);
+      exit(-1);
+    }
+  }
+
+ private:
+  OnlineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OnlineTransducerNeMoModel> model_;
+  std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
+  Endpoint endpoint_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
\ No newline at end of file
diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc
index 52cfb899f..62d93999e 100644
--- a/sherpa-onnx/csrc/online-stream.cc
+++ b/sherpa-onnx/csrc/online-stream.cc
@@ -90,6 +90,12 @@ class OnlineStream::Impl {
 
   std::vector<Ort::Value> &GetStates() { return states_; }
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
+    decoder_states_ = std::move(decoder_states);
+  }
+
+  std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }
+
   const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
 
   std::vector<float> &GetParaformerFeatCache() {
@@ -129,6 +135,7 @@ class OnlineStream::Impl {
   TransducerKeywordResult empty_keyword_result_;
   OnlineCtcDecoderResult ctc_result_;
   std::vector<Ort::Value> states_;  // states for transducer or ctc models
+  std::vector<Ort::Value> decoder_states_;  // states for NeMo transducer models
   std::vector<float> paraformer_feat_cache_;
   std::vector<float> paraformer_encoder_out_cache_;
   std::vector<float> paraformer_alpha_cache_;
@@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }
 
+void OnlineStream::SetNeMoDecoderStates(
+    std::vector<Ort::Value> decoder_states) {
+  return impl_->SetNeMoDecoderStates(std::move(decoder_states));
+}
+
+std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
+  return impl_->GetNeMoDecoderStates();
+}
+
 const ContextGraphPtr &OnlineStream::GetContextGraph() const {
   return impl_->GetContextGraph();
 }
diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h
index 49b7f7402..4e444366e 100644
--- a/sherpa-onnx/csrc/online-stream.h
+++ b/sherpa-onnx/csrc/online-stream.h
@@ -91,6 +91,9 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
 
+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
+  std::vector<Ort::Value> &GetNeMoDecoderStates();
+
   /**
    * Get the context graph corresponding to this stream.
    *
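The new accessors mirror the existing GetStates()/SetStates() pair but keep the NeMo decoder's LSTM states separate from the encoder caches. A sketch of the intended round trip, assuming a stream `s` and a model `model` set up as in the recognizer above:

// Seed the per-stream decoder states once, at stream creation time.
std::vector<Ort::Value> init_states =
    model->GetDecoderInitStates(/*batch_size=*/1);
s->SetNeMoDecoderStates(std::move(init_states));

// Each DecodeStreams() call then reads the states by reference and writes
// the updated states back through the same reference.
std::vector<Ort::Value> &states = s->GetNeMoDecoderStates();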
diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
new file mode 100644
index 000000000..8f95215f7
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
@@ -0,0 +1,198 @@
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+// Copyright (c)  2024  Sangeet Sagar
+
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+
+namespace sherpa_onnx {
+
+static std::pair<Ort::Value, Ort::Value> BuildDecoderInput(
+    int32_t token, OrtAllocator *allocator) {
+  std::array<int64_t, 2> shape{1, 1};
+
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int32_t>(allocator, shape.data(), shape.size());
+
+  std::array<int64_t, 1> length_shape{1};
+  Ort::Value decoder_input_length = Ort::Value::CreateTensor<int32_t>(
+      allocator, length_shape.data(), length_shape.size());
+
+  int32_t *p = decoder_input.GetTensorMutableData<int32_t>();
+
+  int32_t *p_length = decoder_input_length.GetTensorMutableData<int32_t>();
+
+  p[0] = token;
+
+  p_length[0] = 1;
+
+  return {std::move(decoder_input), std::move(decoder_input_length)};
+}
+
+OnlineTransducerDecoderResult
+OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const {
+  int32_t context_size = 8;
+  int32_t blank_id = 0;  // always 0
+  OnlineTransducerDecoderResult r;
+  r.tokens.resize(context_size, -1);
+  r.tokens.back() = blank_id;
+
+  return r;
+}
+
+static void UpdateCachedDecoderOut(
+    OrtAllocator *allocator, const Ort::Value *decoder_out,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  std::vector<int64_t> shape =
+      decoder_out->GetTensorTypeAndShapeInfo().GetShape();
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  std::array<int64_t, 2> v_shape{1, shape[1]};
+
+  const float *src = decoder_out->GetTensorData<float>();
+  for (auto &r : *result) {
+    if (!r.decoder_out) {
+      r.decoder_out = Ort::Value::CreateTensor<float>(
+          allocator, v_shape.data(), v_shape.size());
+    }
+
+    float *dst = r.decoder_out.GetTensorMutableData<float>();
+    std::copy(src, src + shape[1], dst);
+    src += shape[1];
+  }
+}
+
+std::vector<Ort::Value> DecodeOne(
+    const float *encoder_out, int32_t num_rows, int32_t num_cols,
+    OnlineTransducerNeMoModel *model, float blank_penalty,
+    std::vector<Ort::Value> &decoder_states,
+    std::vector<OnlineTransducerDecoderResult> *result) {
+  auto memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+  int32_t vocab_size = model->VocabSize();
+  int32_t blank_id = vocab_size - 1;
+
+  auto &r = (*result)[0];
+  Ort::Value decoder_out{nullptr};
+
+  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
+  // decoder_input_pair.first: decoder_input
+  // decoder_input_pair.second: decoder_input_length (discarded)
+
+  // decoder_output_pair.second holds the next decoder states. On the first
+  // call decoder_states is empty; RunDecoder returns the two LSTM state
+  // tensors.
+  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
+      model->RunDecoder(std::move(decoder_input_pair.first),
+                        std::move(decoder_states));
+
+  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
+
+  decoder_states = std::move(decoder_output_pair.second);
+
+  // Decode the chunk frame by frame.
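+  // Greedy rule applied to each encoder frame t below:
+  //   1. logits = joiner(encoder_out[t], decoder_out)
+  //   2. y = argmax(logits), after subtracting blank_penalty from the blank
+  //      logit if enabled
+  //   3. if y == blank, keep the current decoder output and move on to frame
+  //      t + 1; otherwise emit y, run the decoder on y, and use the new
+  //      decoder output and states from frame t + 1 onwards.
+  // Note: at most one symbol is emitted per frame in this implementation.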
+  for (int32_t t = 0; t != num_rows; ++t) {
+    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
+        memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
+        encoder_shape.data(), encoder_shape.size());
+
+    Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out),
+                                        View(&decoder_output_pair.first));
+
+    float *p_logit = logit.GetTensorMutableData<float>();
+    if (blank_penalty > 0) {
+      p_logit[blank_id] -= blank_penalty;
+    }
+
+    auto y = static_cast<int32_t>(std::distance(
+        static_cast<const float *>(p_logit),
+        std::max_element(static_cast<const float *>(p_logit),
+                         static_cast<const float *>(p_logit) + vocab_size)));
+    SHERPA_ONNX_LOGE("y=%d", y);
+    if (y != blank_id) {
+      r.tokens.push_back(y);
+      r.timestamps.push_back(t + r.frame_offset);
+
+      decoder_input_pair = BuildDecoderInput(y, model->Allocator());
+
+      // The updated decoder output is used for scoring subsequent frames.
+      decoder_output_pair =
+          model->RunDecoder(std::move(decoder_input_pair.first),
+                            std::move(decoder_states));
+
+      // Update the decoder states for the next chunk
+      decoder_states = std::move(decoder_output_pair.second);
+    }
+  }
+
+  decoder_out = std::move(decoder_output_pair.first);
+  // UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
+
+  // Update frame_offset
+  for (auto &r : *result) {
+    r.frame_offset += num_rows;
+  }
+
+  return std::move(decoder_states);
+}
+
+std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
+    Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
+    std::vector<OnlineTransducerDecoderResult> *result,
+    OnlineStream ** /*ss = nullptr*/, int32_t /*n = 0*/) {
+  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
+
+  if (shape[0] != static_cast<int64_t>(result->size())) {
+    SHERPA_ONNX_LOGE(
+        "Size mismatch! encoder_out.size(0) %d, result.size(0): %d",
+        static_cast<int32_t>(shape[0]), static_cast<int32_t>(result->size()));
+    exit(-1);
+  }
+
+  int32_t batch_size = static_cast<int32_t>(shape[0]);  // batch_size is 1
+  int32_t dim1 = static_cast<int32_t>(shape[1]);  // T', e.g., 2
+  int32_t dim2 = static_cast<int32_t>(shape[2]);  // encoder_out_dim, e.g., 512
+
+  // Define and initialize encoder_out_length
+  Ort::MemoryInfo memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+  int64_t length_value = 1;
+  std::vector<int64_t> length_shape = {1};
+
+  Ort::Value encoder_out_length =
+      Ort::Value::CreateTensor(memory_info, &length_value, 1,
+                               length_shape.data(), length_shape.size());
+
+  const int64_t *p_length = encoder_out_length.GetTensorData<int64_t>();
+  const float *p = encoder_out.GetTensorData<float>();
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    const float *this_p = p + dim1 * dim2 * i;
+    int32_t this_len = p_length[i];
+
+    // DecodeOne() returns the decoder states after the last chunk.
+    auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_,
+                                         blank_penalty_, decoder_states,
+                                         result);
+    decoder_states = std::move(last_decoder_states);
+  }
+
+  return decoder_states;
+}
+
+}  // namespace sherpa_onnx
\ No newline at end of file
diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
new file mode 100644
index 000000000..d5a7a078c
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
@@ -0,0 +1,40 @@
+// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+// Copyright (c)  2024  Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerGreedySearchNeMoDecoder {
+ public:
+  OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model,
+                                          float blank_penalty)
+      : model_(model), blank_penalty_(blank_penalty) {}
+
+  OnlineTransducerDecoderResult GetEmptyResult() const;
+  void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
+  void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {}
+
+  std::vector<Ort::Value> Decode(
+      Ort::Value encoder_out, std::vector<Ort::Value> decoder_states,
+      std::vector<OnlineTransducerDecoderResult> *result,
+      OnlineStream **ss = nullptr, int32_t n = 0);
+
+ private:
+  OnlineTransducerNeMoModel *model_;  // Not owned
+  float blank_penalty_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_
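Typical call sequence for this class, as used by OnlineRecognizerTransducerNeMoImpl above (a sketch; `model`, `encoder_out`, and `result` are assumed to be set up as in DecodeStreams()):

OnlineTransducerGreedySearchNeMoDecoder decoder(model.get(),
                                                /*blank_penalty=*/0.0f);

// Per-stream decoder states, seeded once from the model:
std::vector<Ort::Value> states = model->GetDecoderInitStates(/*batch_size=*/1);

// One call per encoder chunk; the returned states are carried over to the
// next chunk.
states = decoder.Decode(std::move(encoder_out), std::move(states), &result,
                        /*ss=*/nullptr, /*n=*/0);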
diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc
new file mode 100644
index 000000000..b054e3b72
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc
@@ -0,0 +1,441 @@
+// sherpa-onnx/csrc/online-transducer-nemo-model.cc
+//
+// Copyright (c)  2024  Xiaomi Corporation
+// Copyright (c)  2024  Sangeet Sagar
+
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+
+#include <assert.h>
+#include <math.h>
+
+#include <algorithm>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+class OnlineTransducerNeMoModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.transducer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.transducer.joiner);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.transducer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.transducer.joiner);
+      InitJoiner(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> RunEncoder(Ort::Value features,
+                                     std::vector<Ort::Value> states) {
+    Ort::Value &cache_last_channel = states[0];
+    Ort::Value &cache_last_time = states[1];
+    Ort::Value &cache_last_channel_len = states[2];
+
+    int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0];
+
+    std::array<int64_t, 1> length_shape{batch_size};
+
+    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
+        allocator_, length_shape.data(), length_shape.size());
+
+    int64_t *p_length = length.GetTensorMutableData<int64_t>();
+
+    std::fill(p_length, p_length + batch_size, ChunkSize());
+
+    // (B, T, C) -> (B, C, T)
+    features = Transpose12(allocator_, &features);
+
+    std::array<Ort::Value, 5> inputs = {
+        std::move(features), View(&length), std::move(cache_last_channel),
+        std::move(cache_last_time), std::move(cache_last_channel_len)};
+
+    auto out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
+        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
+    // out[0]: logit
+    // out[1]: logit_length
+    // out[2:]: states_next
+    //
+    // we need to remove out[1]
+
+    std::vector<Ort::Value> ans;
+    ans.reserve(out.size() - 1);
+
+    for (int32_t i = 0; i != static_cast<int32_t>(out.size()); ++i) {
+      if (i == 1) {
+        continue;
+      }
+
+      ans.push_back(std::move(out[i]));
+    }
+
+    return ans;
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      Ort::Value targets, std::vector<Ort::Value> states) {
+    Ort::MemoryInfo memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+    // Create the targets_length tensor with a single int32_t value of 1
+    int32_t length_value = 1;
+    std::vector<int64_t> length_shape = {1};
+
+    Ort::Value targets_length =
+        Ort::Value::CreateTensor(memory_info, &length_value, 1,
+                                 length_shape.data(), length_shape.size());
+
+    std::vector<Ort::Value> decoder_inputs;
+    decoder_inputs.reserve(2 + states.size());
+
+    decoder_inputs.push_back(std::move(targets));
+    decoder_inputs.push_back(std::move(targets_length));
+
+    for (auto &s : states) {
+      decoder_inputs.push_back(std::move(s));
+    }
+
+    auto decoder_out = decoder_sess_->Run(
+        {}, decoder_input_names_ptr_.data(), decoder_inputs.data(),
+        decoder_inputs.size(), decoder_output_names_ptr_.data(),
+        decoder_output_names_ptr_.size());
+
+    std::vector<Ort::Value> states_next;
+    states_next.reserve(states.size());
+
+    // decoder_out[0]: decoder_output
+    // decoder_out[1]: decoder_output_length (discarded)
+    // decoder_out[2:]: states_next
+
+    for (int32_t i = 0; i != static_cast<int32_t>(states.size()); ++i) {
+      states_next.push_back(std::move(decoder_out[i + 2]));
+    }
+
+    // we discard decoder_out[1]
+    return {std::move(decoder_out[0]), std::move(states_next)};
+  }
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
+    std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                              std::move(decoder_out)};
+    auto logit = joiner_sess_->Run(
+        {}, joiner_input_names_ptr_.data(), joiner_input.data(),
+        joiner_input.size(), joiner_output_names_ptr_.data(),
+        joiner_output_names_ptr_.size());
+
+    return std::move(logit[0]);
+  }
+
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const {
+    std::array<int64_t, 3> s0_shape{pred_rnn_layers_, batch_size,
+                                    pred_hidden_};
+    Ort::Value s0 = Ort::Value::CreateTensor<float>(
+        allocator_, s0_shape.data(), s0_shape.size());
+
+    Fill<float>(&s0, 0);
+
+    std::array<int64_t, 3> s1_shape{pred_rnn_layers_, batch_size,
+                                    pred_hidden_};
+
+    Ort::Value s1 = Ort::Value::CreateTensor<float>(
+        allocator_, s1_shape.data(), s1_shape.size());
+
+    Fill<float>(&s1, 0);
+
+    std::vector<Ort::Value> states;
+
+    states.reserve(2);
+    states.push_back(std::move(s0));
+    states.push_back(std::move(s1));
+
+    return states;
+  }
+
+  int32_t ChunkSize() const { return window_size_; }
+
+  int32_t ChunkShift() const { return chunk_shift_; }
+
+  int32_t SubsamplingFactor() const { return subsampling_factor_; }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  std::string FeatureNormalizationMethod() const { return normalize_type_; }
+
+  // Return a vector containing 3 tensors:
+  // - cache_last_channel
+  // - cache_last_time
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() {
+    std::vector<Ort::Value> ans;
+    ans.reserve(3);
+    ans.push_back(View(&cache_last_channel_));
+    ans.push_back(View(&cache_last_time_));
+    ans.push_back(View(&cache_last_channel_len_));
+
+    return ans;
+  }
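+  // Shapes of the three encoder cache tensors created in InitStates() below,
+  // with all dims read from the encoder metadata in InitEncoder():
+  //   cache_last_channel:     (1, dim1, dim2, dim3), float
+  //   cache_last_time:        (1, dim1, dim2, dim3), float
+  //   cache_last_channel_len: (1,), int64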
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---encoder---\n";
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+
+    // need to increase by 1 since the blank token is not included in
+    // computing vocab_size in NeMo.
+    vocab_size_ += 1;
+
+    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
+    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
+    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
+    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
+    SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
+    SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
+
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
+                               "cache_last_channel_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
+                               "cache_last_channel_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
+                               "cache_last_channel_dim3");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
+    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");
+
+    if (normalize_type_ == "NA") {
+      normalize_type_ = "";
+    }
+
+    InitStates();
+  }
+
+  void InitStates() {
+    std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
+                                                    cache_last_channel_dim2_,
+                                                    cache_last_channel_dim3_};
+
+    cache_last_channel_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_channel_shape.data(),
+        cache_last_channel_shape.size());
+
+    Fill<float>(&cache_last_channel_, 0);
+
+    std::array<int64_t, 4> cache_last_time_shape{
+        1, cache_last_time_dim1_, cache_last_time_dim2_,
+        cache_last_time_dim3_};
+
+    cache_last_time_ = Ort::Value::CreateTensor<float>(
+        allocator_, cache_last_time_shape.data(),
+        cache_last_time_shape.size());
+
+    Fill<float>(&cache_last_time_, 0);
+
+    int64_t shape = 1;
+    cache_last_channel_len_ =
+        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
+
+    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+  void InitJoiner(void *model_data, size_t model_data_length) {
+    joiner_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                  &joiner_input_names_ptr_);
+
+    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                   &joiner_output_names_ptr_);
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  int32_t window_size_;
+  int32_t chunk_shift_;
+  int32_t vocab_size_ = 0;
+  int32_t subsampling_factor_ = 8;
+  std::string normalize_type_;
+  int32_t pred_rnn_layers_ = -1;
+  int32_t pred_hidden_ = -1;
+
+  int32_t cache_last_channel_dim1_;
+  int32_t cache_last_channel_dim2_;
+  int32_t cache_last_channel_dim3_;
+  int32_t cache_last_time_dim1_;
+  int32_t cache_last_time_dim2_;
+  int32_t cache_last_time_dim3_;
+
+  Ort::Value cache_last_channel_{nullptr};
+  Ort::Value cache_last_time_{nullptr};
+  Ort::Value cache_last_channel_len_{nullptr};
+};
+
+OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
+    const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineTransducerNeMoModel::OnlineTransducerNeMoModel(
+    AAssetManager *mgr, const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default;
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::RunEncoder(
+    Ort::Value features, std::vector<Ort::Value> states) const {
+  return impl_->RunEncoder(std::move(features), std::move(states));
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets,
+                                      std::vector<Ort::Value> states) const {
+  return impl_->RunDecoder(std::move(targets), std::move(states));
+}
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::GetDecoderInitStates(
+    int32_t batch_size) const {
+  return impl_->GetDecoderInitStates(batch_size);
+}
+
+Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out,
+                                                Ort::Value decoder_out) const {
+  return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
+}
+
+int32_t OnlineTransducerNeMoModel::ChunkSize() const {
+  return impl_->ChunkSize();
+}
+
+int32_t OnlineTransducerNeMoModel::ChunkShift() const {
+  return impl_->ChunkShift();
+}
+
+int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const {
+  return impl_->SubsamplingFactor();
+}
+
+int32_t OnlineTransducerNeMoModel::VocabSize() const {
+  return impl_->VocabSize();
+}
+
+OrtAllocator *OnlineTransducerNeMoModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const {
+  return impl_->FeatureNormalizationMethod();
+}
+
+std::vector<Ort::Value> OnlineTransducerNeMoModel::GetInitStates() const {
+  return impl_->GetInitStates();
+}
+
+}  // namespace sherpa_onnx
\ No newline at end of file
diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h
new file mode 100644
index 000000000..97a632f50
--- /dev/null
+++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h
@@ -0,0 +1,124 @@
+// sherpa-onnx/csrc/online-transducer-nemo-model.h
+//
+// Copyright (c)  2024  Xiaomi Corporation
+// Copyright (c)  2024  Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+// see
+// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40
+// Its decoder is stateful, not stateless.
+class OnlineTransducerNeMoModel {
+ public:
+  explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineTransducerNeMoModel(AAssetManager *mgr,
+                            const OnlineModelConfig &config);
+#endif
+
+  ~OnlineTransducerNeMoModel();
+
+  // A list of 3 tensors:
+  // - cache_last_channel
+  // - cache_last_time
+  // - cache_last_channel_len
+  std::vector<Ort::Value> GetInitStates() const;
+
+  /** Run the encoder.
+   *
+   * @param features  A tensor of shape (N, T, C). It is changed in-place.
+   * @param states  It is from GetInitStates() or returned from this method.
+   *
+   * @return Return a vector containing:
+   *           - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim)
+   *           - ans[1:]: the next states
+   */
+  std::vector<Ort::Value> RunEncoder(
+      Ort::Value features, std::vector<Ort::Value> states) const;  // NOLINT
+
+  /** Run the decoder network.
+   *
+   * @param targets A int32 tensor of shape (batch_size, 1)
+   * @param states The states for the decoder model.
+   * @return Return a pair:
+   *           - first: decoder_out (a float tensor)
+   *           - second: the next decoder states
+   */
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunDecoder(
+      Ort::Value targets, std::vector<Ort::Value> states) const;
+
+  std::vector<Ort::Value> GetDecoderInitStates(int32_t batch_size) const;
+
+  /** Run the joint network.
+   *
+   * @param encoder_out Output of the encoder network.
+   * @param decoder_out Output of the decoder network.
+   * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits.
+   */
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const;
+
+  /** We send this number of feature frames to the encoder at a time. */
+  int32_t ChunkSize() const;
+
+  /** Number of input frames to discard after each call to RunEncoder.
+   *
+   * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6.
+   *
+   * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8.
+   * Then we discard frames 0~5 since chunk_shift is 6.
+   * In the second call of RunEncoder, we use frames 6~13; and then we discard
+   * frames 6~11.
+   * In the third call of RunEncoder, we use frames 12~19; and then we discard
+   * frames 12~17.
+   *
+   * Note: ChunkSize() - ChunkShift() == right context size
+   */
+  int32_t ChunkShift() const;
+
+  /** Return the subsampling factor of the model. */
+  int32_t SubsamplingFactor() const;
+
+  int32_t VocabSize() const;
+
+  /** Return an allocator for allocating memory. */
+  OrtAllocator *Allocator() const;
+
+  // Possible values:
+  // - per_feature
+  // - all_features (not implemented yet)
+  // - fixed_mean (not implemented)
+  // - fixed_std (not implemented)
+  // - or just leave it to empty
+  // See
+  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
+  // for details
+  std::string FeatureNormalizationMethod() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_
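For reference, a minimal sketch of how this new code path is reached from the public API (file names are placeholders; the recognizer selects OnlineRecognizerTransducerNeMoImpl automatically because the NeMo decoder graph has more than one output):

#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";  // placeholder
  config.model_config.transducer.decoder = "decoder.onnx";  // placeholder
  config.model_config.transducer.joiner = "joiner.onnx";    // placeholder
  config.model_config.tokens = "tokens.txt";                // placeholder
  config.decoding_method = "greedy_search";

  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();
  // Feed audio via stream->AcceptWaveform(...), then call
  // recognizer.DecodeStream(stream.get()) while recognizer.IsReady(...)
  // and read the result with recognizer.GetResult(stream.get()).
  return 0;
}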