diff --git a/examples/text/information_extraction/ernie/cpp/uie.cc b/examples/text/information_extraction/ernie/cpp/uie.cc index 93c5adfeee..0e5a02f5f8 100644 --- a/examples/text/information_extraction/ernie/cpp/uie.cc +++ b/examples/text/information_extraction/ernie/cpp/uie.cc @@ -75,7 +75,7 @@ Schema::Schema( UIEModel::UIEModel(const std::string& model_file, const std::string& params_file, - const std::string& vocab_file, double position_prob, + const std::string& vocab_file, float position_prob, size_t max_length, const std::vector& schema) : max_length_(max_length), position_prob_(position_prob), @@ -90,7 +90,7 @@ UIEModel::UIEModel(const std::string& model_file, UIEModel::UIEModel( const std::string& model_file, const std::string& params_file, - const std::string& vocab_file, double position_prob, size_t max_length, + const std::string& vocab_file, float position_prob, size_t max_length, const std::unordered_map>& schema) : max_length_(max_length), position_prob_(position_prob), @@ -155,6 +155,20 @@ void UIEModel::AutoSplitter( } } +void UIEModel::GetCandidateIdx( + const float* probs, int64_t batch_size, int64_t seq_len, + std::vector>>* candidate_idx_prob, + float threshold) const { + for (int i = 0; i < batch_size; ++i) { + candidate_idx_prob->push_back({}); + for (int j = 0; j < seq_len; ++j) { + if (probs[i * seq_len + j] > threshold) { + candidate_idx_prob->back().push_back({j, probs[i * seq_len + j]}); + } + } + } +} + void UIEModel::PredictUIEInput(const std::vector& input_texts, const std::vector& prompts) { // 1. Shortten the input texts and prompts @@ -214,17 +228,20 @@ void UIEModel::PredictUIEInput(const std::vector& input_texts, inputs[i].name = runtime_.GetInputInfo(i).name; } - std::vector start_probs, end_probs; std::vector outputs(runtime_.NumOutputs()); // 4. Infer runtime_.Infer(inputs, &outputs); auto* start_prob = reinterpret_cast(outputs[0].Data()); auto* end_prob = reinterpret_cast(outputs[1].Data()); - start_probs.insert(start_probs.end(), start_prob, - start_prob + outputs[0].Numel()); - end_probs.insert(end_probs.end(), end_prob, end_prob + outputs[0].Numel()); // 5. Postprocess + std::vector>> start_candidate_idx_prob, + end_candidate_idx_prob; + + GetCandidateIdx(start_prob, outputs[0].shape[0], outputs[0].shape[1], + &start_candidate_idx_prob, position_prob_); + GetCandidateIdx(end_prob, outputs[1].shape[0], outputs[1].shape[1], + &end_candidate_idx_prob, position_prob_); } void UIEModel::Predict(const std::vector& texts, diff --git a/examples/text/information_extraction/ernie/cpp/uie.h b/examples/text/information_extraction/ernie/cpp/uie.h index 583aeeb25a..a638181246 100644 --- a/examples/text/information_extraction/ernie/cpp/uie.h +++ b/examples/text/information_extraction/ernie/cpp/uie.h @@ -70,11 +70,11 @@ struct UIEInput { struct UIEModel { UIEModel(const std::string& model_file, const std::string& params_file, - const std::string& vocab_file, double position_prob, + const std::string& vocab_file, float position_prob, size_t max_length, const std::vector& schema); UIEModel( const std::string& model_file, const std::string& params_file, - const std::string& vocab_file, double position_prob, size_t max_length, + const std::string& vocab_file, float position_prob, size_t max_length, const std::unordered_map>& schema); void SetSchema(const std::vector& schema); void SetSchema( @@ -90,10 +90,16 @@ struct UIEModel { const std::vector& texts, size_t max_length, std::vector* short_texts, std::unordered_map>* input_mapping); + // Get idx of the last dimension in probability arrays, which is greater than + // a limitation. + void GetCandidateIdx( + const float* probs, int64_t batch_size, int64_t seq_len, + std::vector>>* candidate_idx_prob, + float threshold = 0.5) const; fastdeploy::RuntimeOption runtime_option_; fastdeploy::Runtime runtime_; std::unique_ptr schema_; size_t max_length_; - double position_prob_; + float position_prob_; faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_; };