convert_msmarco_passages_doc_to_anserini.py

'''
Segment MS MARCO documents into passages and append each document's URL,
title, and predicted queries to every segment. The segments are then saved
as JSON lines that can be indexed by Anserini.
'''
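
# Example invocation (the file names below are illustrative, not from the repo):
#
#   python convert_msmarco_passages_doc_to_anserini.py \
#     --original_docs_path msmarco-docs.tsv.gz \
#     --doc_ids_path msmarco_doc_passage_ids.txt \
#     --predictions_path predicted_queries.txt \
#     --output_docs_path output/msmarco_docs_segmented.jsonl
#
# Each output line is a JSON object of the form
# {"id": "<doc_id>#<seg_id>", "contents": "<url> <title> <segment> <predicted queries>"}
# (the predicted-query part is omitted when --predictions_path is not given).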
import argparse
import gzip
import json
import os
import spacy
from tqdm import tqdm
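

# create_segments() slides a window over the document's sentences: each
# segment concatenates up to `max_length` sentences, and consecutive windows
# start `stride` sentences apart, so segments overlap when stride < max_length.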
def create_segments(doc_text, max_length, stride):
    doc_text = doc_text.strip()
    # Only sentence-split the first 10,000 characters to bound spaCy's cost
    # on very long documents.
    doc = nlp(doc_text[:10000])
    sentences = [sent.text.strip() for sent in doc.sents]
    segments = []
    for i in range(0, len(sentences), stride):
        segment = " ".join(sentences[i:i + max_length])
        segments.append(segment)
        # Stop once the window has reached the last sentence.
        if i + max_length >= len(sentences):
            break
    return segments
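
# For example, with max_length=3 and stride=2, a five-sentence document
# [s0, s1, s2, s3, s4] yields two segments: "s0 s1 s2" and "s2 s3 s4".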

parser = argparse.ArgumentParser(
    description='Concatenate MS MARCO original docs with predicted queries')
parser.add_argument('--original_docs_path', required=True,
                    help='MS MARCO .tsv corpus file (gzipped).')
parser.add_argument('--doc_ids_path', required=True,
                    help='File mapping segments to doc ids.')
parser.add_argument('--output_docs_path', required=True,
                    help='Output file in the Anserini jsonl format.')
parser.add_argument('--predictions_path', default=None,
                    help='File containing predicted queries.')
parser.add_argument('--max_length', type=int, default=10,
                    help='Maximum number of sentences per segment.')
parser.add_argument('--stride', type=int, default=5,
                    help='Number of sentences between consecutive segment starts.')
args = parser.parse_args()

os.makedirs(os.path.dirname(args.output_docs_path), exist_ok=True)
f_corpus = gzip.open(args.original_docs_path, mode='rt')
f_out = open(args.output_docs_path, 'w')

# A blank English pipeline with a rule-based sentencizer is enough here,
# since only sentence boundaries are needed.
nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")  # spaCy 3 API; spaCy 2 needs nlp.create_pipe("sentencizer").

print('Splitting documents...')
doc_id_ref = None
# zip() is used in both branches so that each item is always a tuple:
# (doc_id_line,) without predictions, (doc_id_line, query_line) with them.
if args.predictions_path is None:
    doc_ids_queries = zip(open(args.doc_ids_path))
else:
    doc_ids_queries = zip(open(args.doc_ids_path), open(args.predictions_path))
for doc_id_query in tqdm(doc_ids_queries):
    doc_id = doc_id_query[0].strip()
    if doc_id != doc_id_ref:
        # New document: advance through the corpus (which is in the same
        # order as doc_ids_path) until the matching doc id is found.
        f_doc_id, doc_url, doc_title, doc_text = next(f_corpus).split('\t')
        while f_doc_id != doc_id:
            f_doc_id, doc_url, doc_title, doc_text = next(f_corpus).split('\t')
        segments = create_segments(doc_text, args.max_length, args.stride)
        seg_id = 0
    else:
        # Same document as the previous line: move to its next segment.
        seg_id += 1
    doc_seg = f'{doc_id}#{seg_id}'
    if seg_id < len(segments):
        segment = segments[seg_id]
        if args.predictions_path is None:
            expanded_text = f'{doc_url} {doc_title} {segment}'
        else:
            predicted_queries_partial = doc_id_query[1].strip()
            expanded_text = f'{doc_url} {doc_title} {segment} {predicted_queries_partial}'
        output_dict = {'id': doc_seg, 'contents': expanded_text}
        f_out.write(json.dumps(output_dict) + '\n')
    doc_id_ref = doc_id

f_corpus.close()
f_out.close()
print('Done!')