Skip to content

Commit

Permalink
Add GetCandidateIdx
Browse files Browse the repository at this point in the history
  • Loading branch information
joey12300 committed Aug 16, 2022
1 parent 5434758 commit f5537a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
29 changes: 23 additions & 6 deletions examples/text/information_extraction/ernie/cpp/uie.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Schema::Schema(

UIEModel::UIEModel(const std::string& model_file,
const std::string& params_file,
const std::string& vocab_file, double position_prob,
const std::string& vocab_file, float position_prob,
size_t max_length, const std::vector<std::string>& schema)
: max_length_(max_length),
position_prob_(position_prob),
Expand All @@ -90,7 +90,7 @@ UIEModel::UIEModel(const std::string& model_file,

UIEModel::UIEModel(
const std::string& model_file, const std::string& params_file,
const std::string& vocab_file, double position_prob, size_t max_length,
const std::string& vocab_file, float position_prob, size_t max_length,
const std::unordered_map<std::string, std::vector<std::string>>& schema)
: max_length_(max_length),
position_prob_(position_prob),
Expand Down Expand Up @@ -155,6 +155,20 @@ void UIEModel::AutoSplitter(
}
}

// Collects, for every row of a row-major [batch_size, seq_len] probability
// buffer, the positions whose probability is strictly greater than `threshold`.
//
// probs:              pointer to batch_size * seq_len floats; element (i, j) is
//                     read as probs[i * seq_len + j].
// batch_size:         number of rows to scan.
// seq_len:            number of entries per row.
// candidate_idx_prob: output. One vector is appended per row (anything already
//                     stored is kept); each appended vector holds
//                     (position, probability) pairs in ascending position
//                     order and is empty when no entry exceeds the threshold.
// threshold:          exclusive lower bound a probability must exceed.
void UIEModel::GetCandidateIdx(
    const float* probs, int64_t batch_size, int64_t seq_len,
    std::vector<std::vector<std::pair<int64_t, float>>>* candidate_idx_prob,
    float threshold) const {
  // Index with int64_t so the counters have the same width as the bounds; an
  // `int` counter would overflow before reaching a bound above INT_MAX.
  for (int64_t i = 0; i < batch_size; ++i) {
    candidate_idx_prob->emplace_back();
    auto& row_candidates = candidate_idx_prob->back();
    const float* row_probs = probs + i * seq_len;
    for (int64_t j = 0; j < seq_len; ++j) {
      // Read the value once and reuse it for both the test and the result.
      const float prob = row_probs[j];
      if (prob > threshold) {
        row_candidates.emplace_back(j, prob);
      }
    }
  }
}

void UIEModel::PredictUIEInput(const std::vector<std::string>& input_texts,
const std::vector<std::string>& prompts) {
// 1. Shorten the input texts and prompts
Expand Down Expand Up @@ -214,17 +228,20 @@ void UIEModel::PredictUIEInput(const std::vector<std::string>& input_texts,
inputs[i].name = runtime_.GetInputInfo(i).name;
}

std::vector<float> start_probs, end_probs;
std::vector<fastdeploy::FDTensor> outputs(runtime_.NumOutputs());
// 4. Infer
runtime_.Infer(inputs, &outputs);
auto* start_prob = reinterpret_cast<float*>(outputs[0].Data());
auto* end_prob = reinterpret_cast<float*>(outputs[1].Data());
start_probs.insert(start_probs.end(), start_prob,
start_prob + outputs[0].Numel());
end_probs.insert(end_probs.end(), end_prob, end_prob + outputs[0].Numel());

// 5. Postprocess
std::vector<std::vector<std::pair<int64_t, float>>> start_candidate_idx_prob,
end_candidate_idx_prob;

GetCandidateIdx(start_prob, outputs[0].shape[0], outputs[0].shape[1],
&start_candidate_idx_prob, position_prob_);
GetCandidateIdx(end_prob, outputs[1].shape[0], outputs[1].shape[1],
&end_candidate_idx_prob, position_prob_);
}

void UIEModel::Predict(const std::vector<std::string>& texts,
Expand Down
12 changes: 9 additions & 3 deletions examples/text/information_extraction/ernie/cpp/uie.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ struct UIEInput {

struct UIEModel {
UIEModel(const std::string& model_file, const std::string& params_file,
const std::string& vocab_file, double position_prob,
const std::string& vocab_file, float position_prob,
size_t max_length, const std::vector<std::string>& schema);
UIEModel(
const std::string& model_file, const std::string& params_file,
const std::string& vocab_file, double position_prob, size_t max_length,
const std::string& vocab_file, float position_prob, size_t max_length,
const std::unordered_map<std::string, std::vector<std::string>>& schema);
void SetSchema(const std::vector<std::string>& schema);
void SetSchema(
Expand All @@ -90,10 +90,16 @@ struct UIEModel {
const std::vector<std::string>& texts, size_t max_length,
std::vector<std::string>* short_texts,
std::unordered_map<size_t, std::vector<size_t>>* input_mapping);
// For each batch row, collect the indices along the last dimension of the
// probability array whose probability is greater than the threshold.
void GetCandidateIdx(
const float* probs, int64_t batch_size, int64_t seq_len,
std::vector<std::vector<std::pair<int64_t, float>>>* candidate_idx_prob,
float threshold = 0.5) const;
fastdeploy::RuntimeOption runtime_option_;
fastdeploy::Runtime runtime_;
std::unique_ptr<Schema> schema_;
size_t max_length_;
double position_prob_;
float position_prob_;
faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_;
};

0 comments on commit f5537a9

Please sign in to comment.