Skip to content

Commit

Permalink
Add relation schema
Browse files Browse the repository at this point in the history
  • Loading branch information
joey12300 committed Aug 22, 2022
1 parent 48b3847 commit c9c3b69
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 24 deletions.
14 changes: 14 additions & 0 deletions examples/text/information_extraction/ernie/cpp/infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,19 @@ int main() {
predictor.Predict({"2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷"
"爱凌以188.25分获得金牌!"},
&results);
std::cout << results << std::endl;

// schema for relation extraction
// schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
predictor.SetSchema({{"竞赛名称",
{SchemaNode("主办方"), SchemaNode("承办方"),
SchemaNode("已举办次数")}}});
results.clear();
predictor.Predict(
{"2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度"
"公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会"
"承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。"},
&results);
std::cout << results << std::endl;
return 0;
}
104 changes: 85 additions & 19 deletions examples/text/information_extraction/ernie/cpp/uie.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <codecvt>
#include <locale>
#include <queue>
#include <sstream>

#include "utils/utf8.h" // faster_tokenizer helper funciton

Expand Down Expand Up @@ -63,10 +64,72 @@ static void CharToByteOffsetMap(const std::string& seq,
offset_mapping->push_back(index);
}

static std::ostream& PrintResult(std::ostream& os, const UIEResult& result,
int tab_size) {
constexpr int TAB_OFFSET = 4;
// Print text
for (int i = 0; i < tab_size; ++i) {
os << " ";
}
os << "text: " << result.text_ << "\n";

// Print probability
for (int i = 0; i < tab_size; ++i) {
os << " ";
}
os << "probability: " << result.probability_ << "\n";

// Print start
for (int i = 0; i < tab_size; ++i) {
os << " ";
}
os << "start: " << result.start_ << "\n";

// Print end
for (int i = 0; i < tab_size; ++i) {
os << " ";
}
os << "end: " << result.end_ << "\n";

// Print relation
if (result.relation_.size() > 0) {
for (int i = 0; i < tab_size; ++i) {
os << " ";
}
os << "relation:\n";
for (auto&& curr_relation : result.relation_) {
for (int i = 0; i < tab_size + TAB_OFFSET; ++i) {
os << " ";
}
os << curr_relation.first << ":\n";
for (int i = 0; i < curr_relation.second.size(); ++i) {
PrintResult(os, curr_relation.second[i],
tab_size + TAB_OFFSET + TAB_OFFSET);
}
}
}
os << "\n";
return os;
}

std::ostream& operator<<(std::ostream& os, const UIEResult& result) {
os << "text = " << result.text_ << "\nprobability = " << result.probability_
<< "\nstart = " << result.start_ << "\nend = " << result.end_;
os << std::endl;
return PrintResult(os, result, 0);
}

std::ostream& operator<<(
std::ostream& os,
const std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>&
results) {
os << "The result:\n";
for (int i = 0; i < results.size(); ++i) {
for (auto&& curr_result : results[i]) {
os << curr_result.first << ": \n";
for (auto&& uie_result : curr_result.second) {
PrintResult(os, uie_result, 4);
}
}
os << std::endl;
}
return os;
}

Expand All @@ -88,7 +151,7 @@ Schema::Schema(const std::vector<std::string>& schema_list,
}

Schema::Schema(
const std::unordered_map<std::string, std::vector<std::string>>& schema_map,
const std::unordered_map<std::string, std::vector<SchemaNode>>& schema_map,
const std::string& name) {
CreateRoot(name);
for (auto& schema_item : schema_map) {
Expand All @@ -114,7 +177,7 @@ UIEModel::UIEModel(const std::string& model_file,
UIEModel::UIEModel(
const std::string& model_file, const std::string& params_file,
const std::string& vocab_file, float position_prob, size_t max_length,
const std::unordered_map<std::string, std::vector<std::string>>& schema)
const std::unordered_map<std::string, std::vector<SchemaNode>>& schema)
: max_length_(max_length),
position_prob_(position_prob),
tokenizer_(vocab_file) {
Expand All @@ -131,7 +194,7 @@ void UIEModel::SetSchema(const std::vector<std::string>& schema) {
}

void UIEModel::SetSchema(
const std::unordered_map<std::string, std::vector<std::string>>& schema) {
const std::unordered_map<std::string, std::vector<SchemaNode>>& schema) {
schema_ = fastdeploy::utils::make_unique<Schema>(schema);
}

Expand All @@ -142,7 +205,8 @@ void UIEModel::AutoSplitter(
size_t cnt_org = 0;
size_t cnt_short = 0;
for (auto& text : texts) {
auto text_len = text.length();
auto text_len = faster_tokenizer::utils::GetUnicodeLenFromUTF8(
text.c_str(), text.length());
if (text_len <= max_length) {
short_texts->push_back(text);
if (input_mapping->count(cnt_org) == 0) {
Expand All @@ -152,12 +216,17 @@ void UIEModel::AutoSplitter(
}
cnt_short += 1;
} else {
std::vector<uint32_t> offset_mapping;
CharToByteOffsetMap(text, &offset_mapping);
for (size_t start = 0; start < text_len; start += max_length) {
size_t end = start + max_length;
if (end > text_len) {
end = text_len;
}
short_texts->emplace_back(text.data() + start, end - start);
auto unicode_start = offset_mapping[start];
auto unicode_end = offset_mapping[end];
short_texts->emplace_back(text.data() + unicode_start,
unicode_end - unicode_start);
}
auto short_idx = cnt_short;
cnt_short += text_len / max_length;
Expand Down Expand Up @@ -391,7 +460,6 @@ void UIEModel::PredictUIEInput(const std::vector<std::string>& input_texts,
// 2. Tokenize the short texts and short prompts
std::vector<faster_tokenizer::core::Encoding> encodings;
tokenizer_.EncodeBatchStrings(text_pair_input, &encodings);

// 3. Construct the input vector tensor
// 3.1 Convert encodings to input_ids, token_type_ids, position_ids, attn_mask
std::vector<int64_t> input_ids, token_type_ids, position_ids, attn_mask;
Expand Down Expand Up @@ -461,8 +529,9 @@ void UIEModel::Predict(
for (auto& node : schema_->root_->children_) {
nodes.push(node);
}
results->resize(texts.size());
while (!nodes.empty()) {
auto& node = nodes.front();
auto node = nodes.front();
nodes.pop();
std::vector<std::vector<size_t>> input_mapping;
size_t idx = 0;
Expand Down Expand Up @@ -498,7 +567,6 @@ void UIEModel::Predict(
// 2. Predict from UIEInput
std::vector<std::vector<UIEResult>> results_list;
PredictUIEInput(input_texts, prompts, &results_list);

// 3. Postprocess
std::vector<std::vector<UIEResult*>> relations;
relations.resize(texts.size());
Expand All @@ -511,7 +579,7 @@ void UIEModel::Predict(
continue;
}
if (curr_result.count(node.name_) == 0) {
curr_result[node.name_] = std::move(results_list[idx]);
curr_result[node.name_] = results_list[idx];
} else {
curr_result[node.name_].insert(curr_result[node.name_].end(),
results_list[idx].begin(),
Expand All @@ -534,8 +602,7 @@ void UIEModel::Predict(
continue;
}
if (new_relations[i][j]->relation_.count(node.name_) == 0) {
new_relations[i][j]->relation_[node.name_] =
std::move(results_list[idx]);
new_relations[i][j]->relation_[node.name_] = results_list[idx];
} else {
auto& curr_result = new_relations[i][j]->relation_[node.name_];
curr_result.insert(curr_result.end(), results_list[idx].begin(),
Expand All @@ -554,20 +621,19 @@ void UIEModel::Predict(
}
}
}

std::vector<std::vector<std::string>> prefix(texts.size());
for (int i = 0; i < input_mapping.size(); ++i) {
auto&& input_mapping_item = input_mapping[i];
for (auto&& idx : input_mapping_item) {
for (int j = 0; j < results_list[idx].size(); ++j) {
prefix[i].push_back(results_list[idx][j].text_ + "\u7684");
auto prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
prefix[i].push_back(prefix_str);
}
}
}

for (auto& node_child : node.children_) {
node_child.relations_ = std::move(relations);
node_child.prefix_ = std::move(prefix);
node_child.relations_ = relations;
node_child.prefix_ = prefix;
nodes.push(node_child);
}
}
Expand Down
22 changes: 17 additions & 5 deletions examples/text/information_extraction/ernie/cpp/uie.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,20 @@ struct UIEResult {
};

std::ostream& operator<<(std::ostream& os, const UIEResult& result);
std::ostream& operator<<(
std::ostream& os,
const std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>&
results);

struct SchemaNode {
std::string name_;
std::vector<std::vector<std::string>> prefix_;
std::vector<std::vector<UIEResult*>> relations_;
std::vector<SchemaNode> children_;

explicit SchemaNode(const std::string& name) : name_(name) {}
explicit SchemaNode(const std::string& name,
const std::vector<SchemaNode>& children = {})
: name_(name), children_(children) {}
void AddChild(const std::string& schema) { children_.emplace_back(schema); }
void AddChild(const std::string& schema,
const std::vector<std::string>& children) {
Expand All @@ -53,14 +59,20 @@ struct SchemaNode {
}
children_.emplace_back(schema_node);
}
void AddChild(const std::string& schema,
const std::vector<SchemaNode>& children) {
SchemaNode schema_node(schema);
schema_node.children_ = children;
children_.emplace_back(schema_node);
}
};

struct Schema {
explicit Schema(const std::string& schema, const std::string& name = "root");
explicit Schema(const std::vector<std::string>& schema_list,
const std::string& name = "root");
explicit Schema(const std::unordered_map<
std::string, std::vector<std::string>>& schema_map,
explicit Schema(const std::unordered_map<std::string,
std::vector<SchemaNode>>& schema_map,
const std::string& name = "root");

private:
Expand All @@ -77,10 +89,10 @@ struct UIEModel {
UIEModel(
const std::string& model_file, const std::string& params_file,
const std::string& vocab_file, float position_prob, size_t max_length,
const std::unordered_map<std::string, std::vector<std::string>>& schema);
const std::unordered_map<std::string, std::vector<SchemaNode>>& schema);
void SetSchema(const std::vector<std::string>& schema);
void SetSchema(
const std::unordered_map<std::string, std::vector<std::string>>& schema);
const std::unordered_map<std::string, std::vector<SchemaNode>>& schema);

void PredictUIEInput(const std::vector<std::string>& input_texts,
const std::vector<std::string>& prompts,
Expand Down

0 comments on commit c9c3b69

Please sign in to comment.