-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict.py
155 lines (141 loc) · 6.92 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import json
import torch
import numpy as np
from collections import namedtuple
from model import BertNer, BertRe
from seqeval.metrics.sequence_labeling import get_entities
from transformers import BertTokenizer
def get_args(args_path, args_name=None):
    """Load a JSON config file into an immutable namedtuple.

    Args:
        args_path: path to a JSON file of flat key/value argument pairs.
        args_name: type name for the generated namedtuple.

    Returns:
        A namedtuple instance whose fields mirror the JSON keys. Being a
        tuple, the returned args object cannot be modified after creation.
    """
    # Force UTF-8: the config files ship with Chinese-text models, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(args_path, "r", encoding="utf-8") as fp:
        args_dict = json.load(fp)
    return namedtuple(args_name, args_dict.keys())(*args_dict.values())
class DgrePredictor:
    """Two-stage aspect-based sentiment predictor.

    Stage 1 (NER): a BERT tagger extracts aspect / opinion / sentiment spans.
    Stage 2 (RE): for each aspect, a BERT pointer model predicts the matching
    opinion span given the prompt "<aspect><prompt>", and the opinion is mapped
    to its sentiment label via the NER output.

    Checkpoints and arg files are loaded from ``./checkpoint/{data_name}/``.
    """

    def __init__(self, data_name):
        self.data_name = data_name
        checkpoint_dir = "./checkpoint/{}/".format(data_name)
        self.ner_args = get_args(os.path.join(checkpoint_dir, "ner_args.json"), "ner_args")
        self.re_args = get_args(os.path.join(checkpoint_dir, "re_args.json"), "re_args")
        # JSON keys are strings; the model emits integer label ids.
        self.ner_id2label = {int(k): v for k, v in self.ner_args.id2label.items()}
        self.tokenizer = BertTokenizer.from_pretrained(self.ner_args.bert_dir)
        self.max_seq_len = self.ner_args.max_seq_len
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ner_model = BertNer(self.ner_args)
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        self.ner_model.load_state_dict(
            torch.load(os.path.join(self.ner_args.output_dir, "pytorch_model_ner.bin"),
                       map_location=self.device))
        self.ner_model.to(self.device)
        self.ner_model.eval()  # disable dropout for inference
        self.re_model = BertRe(self.re_args)
        self.re_model.load_state_dict(
            torch.load(os.path.join(self.re_args.output_dir, "pytorch_model_re.bin"),
                       map_location=self.device))
        self.re_model.to(self.device)
        self.re_model.eval()  # disable dropout for inference
        self.sentiment = ["正面", "中立", "负面"]
        # opinion text -> sentiment label; filled by opinion2sen(). Initialized
        # here so re_predict() cannot crash if opinion2sen() was never called.
        self.opinion2sen_dict = {}

    def ner_tokenizer(self, text):
        """Encode text character-by-character into padded NER model inputs.

        Returns:
            (input_ids, attention_mask) tensors of shape (1, max_seq_len).
        """
        # Reserve two positions for [CLS] and [SEP].
        text = text[:self.max_seq_len - 2]
        tokens = ["[CLS]"] + list(text) + ["[SEP]"]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        pad_len = self.max_seq_len - len(token_ids)
        input_ids = token_ids + [0] * pad_len
        attention_mask = [1] * len(token_ids) + [0] * pad_len
        input_ids = torch.tensor(np.array([input_ids]))
        attention_mask = torch.tensor(np.array([attention_mask]))
        return input_ids, attention_mask

    def re_tokenizer(self, text, aspect, prompt):
        """Encode a sentence pair "[CLS] aspect prompt [SEP] text [SEP]".

        Returns:
            (input_ids, attention_mask, token_type_ids) tensors of shape
            (1, max_seq_len).
        """
        # Three special-token slots ([CLS], two [SEP]) plus the aspect+prompt prefix.
        pre_length = 3 + len(aspect) + len(prompt)
        text = text[:self.max_seq_len - pre_length]
        tokens = ["[CLS]"] + list(aspect) + list(prompt) + ["[SEP]"] + list(text) + ["[SEP]"]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        pad_len = self.max_seq_len - len(token_ids)
        input_ids = token_ids + [0] * pad_len
        attention_mask = [1] * len(token_ids) + [0] * pad_len
        # NOTE(review): segment ids are all zero even for the second segment —
        # presumably matching how the RE model was trained; confirm before changing.
        token_type_ids = [0] * self.max_seq_len
        input_ids = torch.tensor(np.array([input_ids]))
        token_type_ids = torch.tensor(np.array([token_type_ids]))
        attention_mask = torch.tensor(np.array([attention_mask]))
        return input_ids, attention_mask, token_type_ids

    def re_predict(self, text, ner_result, prompt="怎么样?"):
        """Extract (aspect, opinion, sentiment) triples.

        For every NER entity whose label is not a sentiment label, query the
        RE model with "<aspect><prompt>" against the text and keep the first
        predicted (start, end) opinion span.

        Args:
            text: the raw input text.
            ner_result: output of ner_predict().
            prompt: question suffix appended to each aspect.

        Returns:
            List of (aspect, opinion_text, sentiment) tuples; sentiment is ""
            when the opinion was not seen in opinion2sen_dict.
        """
        res = []
        for label, spans in ner_result.items():
            if label in self.sentiment:
                continue
            for span in spans:
                aspect = span[0]
                input_ids, attention_mask, token_type_ids = self.re_tokenizer(text, aspect, prompt)
                input_ids = input_ids.to(self.device)
                token_type_ids = token_type_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                with torch.no_grad():  # inference only; skip autograd bookkeeping
                    output = self.re_model(input_ids, token_type_ids, attention_mask)
                start_pred = np.argmax(output.start_logits.detach().cpu().numpy(), -1)[0]
                end_pred = np.argmax(output.end_logits.detach().cpu().numpy(), -1)[0]
                # Offset of the raw text inside the model input:
                # [CLS] + aspect + prompt + [SEP].
                offset = 2 + len(aspect) + len(prompt)
                start_pred = start_pred[offset:]
                end_pred = end_pred[offset:]
                found = False
                # Keep only the first (start, end) pair, like the original logic.
                for i, s in enumerate(start_pred):
                    for j, e in enumerate(end_pred):
                        if s == 1 and e == 1:
                            opinion = text[i:j + 1]
                            sentiment = self.opinion2sen_dict.get(opinion, "")
                            res.append((aspect, opinion, sentiment))
                            found = True
                            break
                    if found:
                        break
        return res

    def ner_predict(self, text):
        """Run NER over text.

        Returns:
            Dict mapping entity label -> list of (entity_text, start, end).
        """
        input_ids, attention_mask = self.ner_tokenizer(text)
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output = self.ner_model(input_ids, attention_mask)
        attention_mask = attention_mask.detach().cpu().numpy()
        length = sum(attention_mask[0])
        # output.logits is indexed directly with label ids, so it is assumed to
        # already hold decoded ids (e.g. CRF decode), not raw scores — TODO
        # confirm against model.BertNer.
        pred_ids = output.logits[0][1:length - 1]  # strip [CLS]/[SEP] positions
        labels = [self.ner_id2label[i] for i in pred_ids]
        entities = get_entities(labels)
        result = {}
        for ent_name, ent_start, ent_end in entities:
            result.setdefault(ent_name, []).append(
                ("".join(text[ent_start:ent_end + 1]), ent_start, ent_end))
        return result

    def opinion2sen(self, ner_result):
        """Build and cache a mapping from opinion text to its sentiment label.

        Args:
            ner_result: output of ner_predict().

        Returns:
            The {opinion_text: sentiment_label} dict (also stored on self).
        """
        self.opinion2sen_dict = {}
        for label, spans in ner_result.items():
            if label in self.sentiment:
                for span in spans:
                    self.opinion2sen_dict[span[0]] = label
        return self.opinion2sen_dict
if __name__ == "__main__":
    import ast

    data_name = "gdcq"
    dgre_predictor = DgrePredictor(data_name)
    # Each line of dev.txt is a Python dict literal with a "text" field.
    # Force UTF-8: the data is Chinese text and the platform default
    # encoding is not guaranteed.
    with open("./data/gdcq/re_data/dev.txt", "r", encoding="utf-8") as fp:
        data = fp.read().strip().split("\n")
    for i, line in enumerate(data):
        # literal_eval parses the dict literal without executing code
        # (the original eval() would run arbitrary file contents).
        d = ast.literal_eval(line)
        text = "".join(d["text"])
        ner_result = dgre_predictor.ner_predict(text)
        # Populate the opinion -> sentiment map before relation extraction.
        dgre_predictor.opinion2sen(ner_result)
        re_result = dgre_predictor.re_predict(text, ner_result)
        print("文本>>>>>", text)
        print("实体>>>>>", ner_result)
        print("关系>>>>>", re_result)
        print("=" * 100)
        if i > 10:  # preview only the first few dev examples
            break