# semantic_search.py
from math import sqrt
from utils.data_process import Data
from utils.tf_idf import TF_IDF
from utils.img import show_image
import logging
import requests
class Semantic_Search(object):
    # threshold filters out documents that contain fewer than
    # threshold * len(query) of the query words
    def __init__(self, tf_idf_table, header_tf_idf_table, dictionary):
        self.ths = 0.5
        self.res = 10
        self.tf_idf = tf_idf_table
        self.header_tf_idf = header_tf_idf_table
        # word -> word id mapping constructed from the database
        self.dict = dictionary
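    # e.g. with threshold 0.5 and a 4-word query, a document must contain at
    # least 2 of the query words before its cosine score is computed at all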
    # build a term-frequency vector for the query by treating it as a
    # one-document corpus
    def _gen_tf_(self, query):
        query = Data.dump(query)
        query_list = []
        for word in query:
            if word in self.dict:
                query_list.append(self.dict[word])
        tf = TF_IDF([query_list], None, None, None)
        tf.gen_tf()
        return tf.tf[0]
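    # Illustrative only (assuming Data.dump tokenises the raw string and
    # self.dict maps word -> word id): a query such as "apple apple pie"
    # becomes a list of word ids, and tf.tf[0] is a {word_id: tf} dict for
    # that single pseudo-document; the exact tf scaling is whatever
    # TF_IDF.gen_tf applies.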
    # score = weighted mix of body and header similarity; the tf-idf vector
    # length enters the score as a geometric factor (see the note below)
    def search(self, query, threshold=0.3, return_results=10, len_weight=0.3, header_weight=0.2):
        self.ths = threshold
        self.res = return_results
        query_tf = self._gen_tf_(query)
        best_rank = [(0, 0) for i in range(self.res)]
        for docid in range(len(self.tf_idf)):
            rescos, reslen = self.calcu(query_tf, self.tf_idf[docid])
            headercos, headerlen = self.calcu(
                query_tf, self.header_tf_idf[docid])
            # earlier arithmetic weighting, kept for reference:
            # res = header_weight*(len_weight*headerlen+(1-len_weight)*headercos) + \
            #     (1-header_weight)*(len_weight*reslen+(1-len_weight)*rescos)
            res = (1-header_weight)*pow(rescos, 1-len_weight)*pow(reslen, len_weight) + \
                header_weight*pow(headercos, 1-len_weight) * \
                pow(headerlen, len_weight)
            if res > best_rank[-1][0]:
                # replace the current worst candidate and keep the list sorted
                best_rank[-1] = (res, docid)
                best_rank.sort(reverse=True)
        return best_rank
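    # Worked example of the two weighting schemes above (numbers illustrative):
    # with cos = 0.8, len = 2.0 and len_weight = 0.3, the geometric form gives
    # 0.8**0.7 * 2.0**0.3 ≈ 0.855 * 1.231 ≈ 1.053, while the commented-out
    # arithmetic form gives 0.3*2.0 + 0.7*0.8 = 1.16. The geometric form is
    # zero whenever either factor is zero, so a document with no cosine
    # overlap can never be ranked on length alone.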
    # compute the cosine similarity and the (query-restricted) document
    # vector length
    def calcu(self, query_tf, doc_tf_idf):
        if len(query_tf) == 0:
            return 0, 0
        hit = 0
        for wordid in query_tf.keys():
            if wordid in doc_tf_idf:
                hit += 1
        # prune documents that contain too few of the query words
        if hit < self.ths*len(query_tf):
            return 0, 0
        dotsum = 0
        len_q = 0
        len_d = 0
        for wordid, tf in query_tf.items():
            len_q += tf*tf
            if wordid in doc_tf_idf:
                dotsum += tf*doc_tf_idf[wordid]
                len_d += doc_tf_idf[wordid]*doc_tf_idf[wordid]
        if len_q == 0 or len_d == 0:
            return 0, 0
        return dotsum/sqrt(len_q*len_d), sqrt(len_d)
    # TODO: try numpy's dot for the vector maths above
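# A minimal sketch of the numpy variant the TODO above suggests (numpy is an
# assumption here, not a dependency this file declares). Densifying both
# vectors over the query's word ids keeps the semantics of
# Semantic_Search.calcu, including counting only query words towards the
# document length.
def calcu_numpy(query_tf, doc_tf_idf, threshold=0.5):
    import numpy as np
    if len(query_tf) == 0:
        return 0, 0
    word_ids = list(query_tf.keys())
    q = np.array([query_tf[w] for w in word_ids], dtype=float)
    d = np.array([doc_tf_idf.get(w, 0.0) for w in word_ids], dtype=float)
    # same pruning rule as calcu: too few query words present in the document
    if np.count_nonzero(d) < threshold * len(word_ids):
        return 0, 0
    len_q = float(np.dot(q, q))
    len_d = float(np.dot(d, d))
    if len_q == 0 or len_d == 0:
        return 0, 0
    return float(np.dot(q, d)) / sqrt(len_q * len_d), sqrt(len_d)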
def load(path='output'):
    import zstd
    import pickle
    # each artifact on disk is a zstd-compressed pickle
    with open(f'{path}/tf_idf_matrix.zstd', 'rb') as f:
        tf_idf = pickle.loads(zstd.decompress(f.read()))
    with open(f'{path}/header_tf_idf_matrix.zstd', 'rb') as f:
        header_tf_idf = pickle.loads(zstd.decompress(f.read()))
    with open(f'{path}/dictionary.zstd', 'rb') as f:
        dictionary = pickle.loads(zstd.decompress(f.read()))
    with open(f'{path}/metadata.zstd', 'rb') as f:
        metadata = pickle.loads(zstd.decompress(f.read()))
    return tf_idf, header_tf_idf, dictionary, metadata
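# The writer side is not part of this file; below is a minimal sketch of a
# counterpart that would produce the files load() expects. The paths and the
# zstd+pickle format are taken from load() above; the function name and
# signature are assumptions for illustration.
def save(tf_idf, header_tf_idf, dictionary, metadata, path='output'):
    import zstd
    import pickle
    blobs = {
        'tf_idf_matrix': tf_idf,
        'header_tf_idf_matrix': header_tf_idf,
        'dictionary': dictionary,
        'metadata': metadata,
    }
    for name, obj in blobs.items():
        with open(f'{path}/{name}.zstd', 'wb') as f:
            f.write(zstd.compress(pickle.dumps(obj)))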
if __name__ == '__main__':
    # Note: the Python API caps a single data blob at 2 GB; since all the
    # source files together total only 1.9 GB, the limit is safe to ignore here
    logging.basicConfig(level=logging.INFO)
    logging.info("Loading data from file")
    tf_idf, header_tf_idf, dictionary, metadata = load()
    ss = Semantic_Search(tf_idf, header_tf_idf, dictionary)
    print("Ctrl + C to exit")
    while True:
        query = input("Enter words for semantic search: ")
        try:
            res = ss.search(query, 0.5, 10, 0.6)
        except Exception:
            # malformed queries are simply skipped
            continue
        if res[0][0] == 0:
            print('Not found')
        else:
            for docid in range(len(res)):
                if res[docid][0] > 0:
                    imgurl = metadata[res[docid][1]]['img']
                    title = metadata[res[docid][1]]['title']
                    uid = metadata[res[docid][1]]['id']
                    print('{}:\t{}\t{}'.format(
                        res[docid][0], title, uid), end='\t')
                    if len(imgurl) > 0:
                        print('Found image')
                        try:
                            ret = show_image(imgurl, title)
                            if ret is False:
                                logging.error('No suitable image viewer found')
                        except requests.exceptions.RequestException as e:
                            logging.error('Image loading failed: %s', e)
                    else:
                        print('Image not found')