demo.py
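"""Interactive demo for the document reader model.

Loads a trained checkpoint and the preprocessing metadata, then repeatedly
reads a context paragraph and a question from stdin and prints the predicted
answer span together with the prediction time.
"""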
import time
import argparse
import torch
import msgpack
from model.model import DocReaderModel
from model.utils import str2bool
from prepro import annotate, to_id, init
from train import BatchGen

parser = argparse.ArgumentParser(
    description='Interact with document reader model.'
)
parser.add_argument('--model-file', default='models/best_model.pt',
                    help='path to model file')
parser.add_argument('--cuda', type=str2bool, nargs='?',
                    const=True, default=torch.cuda.is_available(),
                    help='whether to use GPU acceleration.')
args = parser.parse_args()

# Load the checkpoint; on CPU-only machines, remap CUDA tensors to the CPU.
if args.cuda:
    checkpoint = torch.load(args.model_file)
else:
    checkpoint = torch.load(args.model_file,
                            map_location=lambda storage, loc: storage)

state_dict = checkpoint['state_dict']
opt = checkpoint['config']

# Load the preprocessing metadata: vocabulary, POS/NER tag sets, and the
# pretrained word embeddings. Note: the `encoding` argument requires
# msgpack < 1.0.
with open('SQuAD/meta.msgpack', 'rb') as f:
    meta = msgpack.load(f, encoding='utf8')
embedding = torch.Tensor(meta['embedding'])

# Fill in the model options that depend on the metadata.
opt['pretrained_words'] = True
opt['vocab_size'] = embedding.size(0)
opt['embedding_dim'] = embedding.size(1)
opt['pos_size'] = len(meta['vocab_tag'])
opt['ner_size'] = len(meta['vocab_ent'])
opt['cuda'] = args.cuda
BatchGen.pos_size = opt['pos_size']
BatchGen.ner_size = opt['ner_size']

model = DocReaderModel(opt, embedding, state_dict)
if args.cuda:
    model.cuda()

# Lookup tables mapping tokens, POS tags, and entity types to ids.
w2id = {w: i for i, w in enumerate(meta['vocab'])}
tag2id = {w: i for i, w in enumerate(meta['vocab_tag'])}
ent2id = {w: i for i, w in enumerate(meta['vocab_ent'])}
init()  # initialize the annotation pipeline defined in prepro

id_ = 0
while True:
    # Read a non-empty context paragraph and question; exit on EOF (Ctrl-D).
    try:
        while True:
            text = input('Text: ')
            if text.strip():
                break
        while True:
            question = input('Question: ')
            if question.strip():
                break
    except EOFError:
        break
    id_ += 1
    start_time = time.time()
    # Annotate the (id, context, question) triple, convert tokens to ids,
    # and wrap it in a single-example evaluation batch.
    annotated = annotate(('interact-{}'.format(id_), text, question), meta['wv_cased'])
    model_in = to_id(annotated, w2id, tag2id, ent2id)
    model_in = next(iter(BatchGen([model_in], batch_size=1, gpu=args.cuda, evaluation=True)))
    prediction = model.predict(model_in)[0]
    end_time = time.time()
    print('Answer: {}'.format(prediction))
    print('Time: {:.4f}s'.format(end_time - start_time))
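
# Example session (illustrative only; the text, question, and answer below
# are made up, and real output depends on the trained checkpoint):
#
#   $ python demo.py --model-file models/best_model.pt
#   Text: The quick brown fox jumps over the lazy dog.
#   Question: What does the fox jump over?
#   Answer: the lazy dog
#   Time: 0.1234s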