helper.py
"""
This file mainly provides helper functions used by the other modules.
"""
import json
import os

import numpy as np
from tqdm import tqdm

def readRankFile(CONF, SCONF):
    # Read the ranked run file from disk (the run.msmarco-passage.dev.small.tsv file)
    rankFilePath = SCONF["RANK_FILE"]
    with open(rankFilePath, 'r') as f:
        lines = f.readlines()
    # Clean the collection by stripping the trailing \n and splitting on \t
    cleanedCol = []
    for each in lines:
        temp = each.replace("\n", "")
        if CONF["METHOD"] == "PASS":
            temp = temp.split("\t")
            cleanedCol.append(temp)
        else:
            temp = temp.split("\t")
            newtemp = [temp[0], temp[1], temp[2]]
            cleanedCol.append(newtemp)
    fullCollection = []
    tempCollection = []
    lastCollection = []
    currentID = cleanedCol[0][0]
    # Group the rows by query id. This produces a 6980-element list, one entry per
    # query, each holding the document ids retrieved for that query (normally 1000,
    # but some queries have fewer).
    for element in cleanedCol:
        if element[0] == currentID:
            tempCollection.append(element)
        else:
            currentID = element[0]
            fullCollection.append(np.array(tempCollection))
            tempCollection = [element]
        # The loop above only flushes a group when the next query id appears, so rows
        # belonging to the last query id are collected separately and appended at the end.
        if element[0] == cleanedCol[-1][0]:
            lastCollection.append(element)
    fullCollection.append(np.array(lastCollection))
    return fullCollection
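
# Illustrative usage of readRankFile (a sketch; the exact CONF/SCONF keys and paths
# come from the caller's config, and the run file is assumed to contain tab-separated
# "<query_id>\t<doc_id>\t<rank>" lines):
#   SCONF = {"RANK_FILE": "run.msmarco-passage.dev.small.tsv"}
#   CONF = {"METHOD": "PASS"}
#   ranked = readRankFile(CONF, SCONF)
#   # ranked is a list with one numpy array per query id; each row of an array is one
#   # cleaned line of the run file for that query.
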
def readQueryFile(path):
    # Read the query file (one JSON object per line) into memory and build a dictionary
    queryCollection = []
    queryDict = {}
    queryFilePath = path
    with open(queryFilePath, 'r') as f:
        contents = f.readlines()
    for line in contents:
        queryContent = json.loads(line)
        queryCollection.append(queryContent)
    for query in queryCollection:
        queryDict[query["id"]] = query["contents"]
    return queryDict
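
# Illustrative usage of readQueryFile (a sketch; the path is hypothetical and the file
# is assumed to be JSONL with "id" and "contents" fields on every line):
#   queries = readQueryFile("queries.dev.small.jsonl")
#   # queries maps query id -> query text, e.g. {"<query_id>": "<query text>", ...}
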
def readCollectionFile(CONF, RANKED_FILE_CONTENT):
    # Collect every unique document id that appears in the ranked run, then look up
    # the contents of each one in the collection files
    dic = {}
    for query in RANKED_FILE_CONTENT:
        for document in query:
            if document[1] not in dic:
                dic[document[1]] = ""
            else:
                continue
    collectionDict = getDocumentContent(CONF, dic)
    return collectionDict
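
# Note on readCollectionFile: the document id sits in column 1 of each run row, so the
# returned dictionary maps doc_id -> document contents for exactly the documents that
# were retrieved by at least one query.
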
def getDocumentContent(CONF, dic):
    # Scan all 9 collection_jsonl files to find the corresponding contents
    collectionPaths = CONF["COLLECTION"]
    for path in tqdm(collectionPaths, desc='Loading Collection...'):
        with open(path, "r") as f:
            lines = f.readlines()
        for line in lines:
            jsonL = json.loads(line)
            if jsonL["id"] in dic:
                dic[jsonL["id"]] = jsonL["contents"]
            else:
                continue
    return dic
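
# Each collection_jsonl line is assumed to look like {"id": "<doc_id>", "contents": "<text>"},
# so only ids already present in `dic` get their contents filled in; everything else is skipped.
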
def results_to_trecrun(folder_path, output, runid):
    # Combine the result files in a folder into a single TREC run file, ready to submit.
    files = os.listdir(folder_path)
    with open(output, "a") as trecrun_file:
        for fname in files:
            if fname == '.DS_Store':
                continue
            file_path = folder_path + '/' + fname
            with open(file_path) as f:
                for line in f:
                    line = line.strip()
                    topicid, docid, rank, score = line.split("\t")
                    line = "{} Q0 {} {} {} {}".format(
                        topicid, docid, rank, score, runid)
                    trecrun_file.write(line + "\n")
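
# Illustrative usage of results_to_trecrun (a sketch; the folder, output name and run id
# are hypothetical). Each tab-separated input line "<topicid>\t<docid>\t<rank>\t<score>"
# becomes a TREC-format line "<topicid> Q0 <docid> <rank> <score> <runid>":
#   results_to_trecrun("results/dev_small", "trec_run.txt", "myrun")
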
def results_to_msrun(folder_path, output):
    # Combine the result files in a folder into a single MS MARCO-style run file, ready to submit.
    files = os.listdir(folder_path)
    with open(output, "a") as msrun_file:
        for fname in files:
            if fname == '.DS_Store':
                continue
            file_path = folder_path + '/' + fname
            with open(file_path) as f:
                for line in f:
                    line = line.strip()
                    topicid, docid, rank, score = line.split(" ")
                    line = "{}\t{}\t{}\t{}".format(topicid, docid, rank, score)
                    msrun_file.write(line + "\n")
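
# Illustrative usage of results_to_msrun (a sketch; file names are hypothetical). Each
# space-separated input line "<topicid> <docid> <rank> <score>" is rewritten as the same
# four fields separated by tabs:
#   results_to_msrun("results/dev_small", "msmarco_run.tsv")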