generate_summary.py
"""Generate summaries with a fine-tuned (multilingual) BART model via fairseq.

Reads one source document per line from --source_text_path, decodes in
batches, and writes one summary per line to --output_path.
"""
import argparse

import torch
from fairseq.models.bart import BARTModel

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str)
parser.add_argument('--data_path', type=str)
parser.add_argument('--source_text_path', type=str)
parser.add_argument('--output_path', type=str)
parser.add_argument('--beam', type=int, default=4)
parser.add_argument('--lenpen', type=float, default=1.0)
parser.add_argument('--sentence_piece_model', type=str, default='sentence_piece_multilingual.model')
parser.add_argument('--max_len_b', type=int, default=200)
parser.add_argument('--min_len', type=int, default=3)
parser.add_argument('--no_repeat_ngram_size', type=int, default=3)
args = parser.parse_args()

# Load the fine-tuned checkpoint along with its task dictionaries and
# SentencePiece model.
bart = BARTModel.from_pretrained(
    '.',
    checkpoint_file=args.model_path,
    data_name_or_path=args.data_path,
    bpe='sentencepiece',
    sentencepiece_vocab=args.sentence_piece_model,
    task='translation',
)
bart.cuda()
bart.eval()
bart.half()

count = 1
bsz = 32  # number of source documents decoded per batch

with open(args.source_text_path) as source, open(args.output_path, 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Once a full batch has been collected, decode it and write the
        # summaries before starting the next batch.
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(
                    slines,
                    beam=args.beam,
                    lenpen=args.lenpen,
                    max_len_b=args.max_len_b,
                    min_len=args.min_len,
                    no_repeat_ngram_size=args.no_repeat_ngram_size,
                )
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []
        slines.append(sline.strip())
        count += 1

    # Decode whatever is left over after the last full batch.
    if slines:
        with torch.no_grad():
            hypotheses_batch = bart.sample(
                slines,
                beam=args.beam,
                lenpen=args.lenpen,
                max_len_b=args.max_len_b,
                min_len=args.min_len,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            )
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
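
For a quick sanity check of the checkpoint and decoding settings before running a full file, the same fairseq API used above can be exercised on a couple of in-memory sentences. This is only a sketch: the checkpoint name and data directory below are placeholders for whatever you would otherwise pass as --model_path and --data_path, and the SentencePiece path simply reuses the script's default.

# sanity_check.py - minimal sketch; paths below are placeholders.
import torch
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    '.',
    checkpoint_file='checkpoint_best.pt',                     # placeholder for --model_path
    data_name_or_path='data-bin',                             # placeholder for --data_path
    bpe='sentencepiece',
    sentencepiece_vocab='sentence_piece_multilingual.model',  # default from the script above
    task='translation',
)
bart.cuda()
bart.eval()
bart.half()

with torch.no_grad():
    summaries = bart.sample(
        ['First source document goes here.', 'Second source document goes here.'],
        beam=4,
        lenpen=1.0,
        max_len_b=200,
        min_len=3,
        no_repeat_ngram_size=3,
    )
print('\n'.join(summaries))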