Skip to content

Commit

Permalink
Make WmdSimilarity and SoftCosineSimilarity handle empty queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Witiko committed Jan 7, 2018
1 parent cdc5811 commit defaf4d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 50 deletions.
4 changes: 2 additions & 2 deletions gensim/similarities/docsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def get_similarities(self, query):
# Convert document indexes to actual documents.
query = [self.corpus[i] for i in query]

if not isinstance(query[0], list):
if not query or not isinstance(query[0], list):
query = [query]

n_queries = len(query)
Expand Down Expand Up @@ -724,7 +724,7 @@ def get_similarities(self, query):
# Convert document indexes to actual documents.
query = [self.corpus[i] for i in query]

if not isinstance(query[0], list):
if not query or not isinstance(query[0], list):
query = [query]

n_queries = len(query)
Expand Down
75 changes: 27 additions & 48 deletions gensim/test/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class _TestSimilarityABC(object):
Base class for SparseMatrixSimilarity and MatrixSimilarity unit tests.
"""

def factoryMethod(self):
    """Build the similarity index exercised by the shared tests.

    Instantiates ``self.cls`` with ``corpus`` and ``num_features`` set to
    ``len(dictionary)``. Test cases whose index class needs different
    constructor arguments override this method.
    """
    feature_count = len(dictionary)
    return self.cls(corpus, num_features=feature_count)

def testFull(self, num_best=None, shardsize=100):
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=shardsize)
Expand Down Expand Up @@ -104,6 +107,15 @@ def test_scipy2scipy_clipped(self):
self.assertTrue(scipy.sparse.issparse(matrix_scipy_clipped))
self.assertTrue([matutils.scipy2sparse(x) for x in matrix_scipy_clipped], [expected] * 3)

def testEmptyQuery(self):
index = self.factoryMethod()
query = []
try:
sims = index[query]
self.assertTrue(sims is not None)
except IndexError:
self.assertTrue(False)

def testChunking(self):
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
Expand Down Expand Up @@ -156,14 +168,7 @@ def testPersistency(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
index.save(fname)
index2 = self.cls.load(fname)
if self.cls == similarities.Similarity:
Expand All @@ -183,14 +188,7 @@ def testPersistencyCompressed(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
index.save(fname)
index2 = self.cls.load(fname)
if self.cls == similarities.Similarity:
Expand All @@ -210,14 +208,7 @@ def testLarge(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
# store all arrays separately
index.save(fname, sep_limit=0)

Expand All @@ -239,14 +230,7 @@ def testLargeCompressed(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
# store all arrays separately
index.save(fname, sep_limit=0)

Expand All @@ -268,14 +252,7 @@ def testMmap(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
# store all arrays separately
index.save(fname, sep_limit=0)

Expand All @@ -298,14 +275,7 @@ def testMmapCompressed(self):
return

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
if self.cls == similarities.Similarity:
index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5)
elif self.cls == similarities.WmdSimilarity:
index = self.cls(texts, self.w2v_model)
elif self.cls == similarities.SoftCosineSimilarity:
index = self.cls(self.corpus, self.similarity_matrix)
else:
index = self.cls(corpus, num_features=len(dictionary))
index = self.factoryMethod()
# store all arrays separately
index.save(fname, sep_limit=0)

Expand All @@ -323,6 +293,9 @@ def setUp(self):
self.cls = similarities.WmdSimilarity
self.w2v_model = Word2Vec(texts, min_count=1)

def factoryMethod(self):
    """Build the index under test from ``texts`` and this fixture's ``w2v_model``."""
    index_cls, model = self.cls, self.w2v_model
    return index_cls(texts, model)

def testFull(self, num_best=None):
# Override testFull.

Expand Down Expand Up @@ -405,6 +378,9 @@ def setUp(self):
similarity_matrix[dictionary.token2id["human"], dictionary.token2id["user"]] = 0.5
self.similarity_matrix = similarity_matrix.tocsc()

def factoryMethod(self):
    """Build the index under test from this fixture's corpus and similarity matrix."""
    index_cls = self.cls
    return index_cls(self.corpus, self.similarity_matrix)

def testFull(self, num_best=None):
# Override testFull.

Expand Down Expand Up @@ -504,6 +480,9 @@ class TestSimilarity(unittest.TestCase, _TestSimilarityABC):
def setUp(self):
    """Select ``similarities.Similarity`` as the index class used by this test case."""
    self.cls = similarities.Similarity

def factoryMethod(self):
    """Build the index under test.

    Calls ``self.cls`` with ``None`` as the first positional argument, then
    ``corpus``, with ``num_features=len(dictionary)`` and ``shardsize=5``.
    """
    options = {"num_features": len(dictionary), "shardsize": 5}
    return self.cls(None, corpus, **options)

def testSharding(self):
for num_best in [None, 0, 1, 9, 1000]:
for shardsize in [1, 2, 9, 1000]:
Expand Down

0 comments on commit defaf4d

Please sign in to comment.