diff --git a/gensim/similarities/docsim.py b/gensim/similarities/docsim.py index ac93119b1e..6f9db34526 100755 --- a/gensim/similarities/docsim.py +++ b/gensim/similarities/docsim.py @@ -639,7 +639,7 @@ def get_similarities(self, query): # Convert document indexes to actual documents. query = [self.corpus[i] for i in query] - if not isinstance(query[0], list): + if not query or not isinstance(query[0], list): query = [query] n_queries = len(query) @@ -724,7 +724,7 @@ def get_similarities(self, query): # Convert document indexes to actual documents. query = [self.corpus[i] for i in query] - if not isinstance(query[0], list): + if not query or not isinstance(query[0], list): query = [query] n_queries = len(query) diff --git a/gensim/test/test_similarities.py b/gensim/test/test_similarities.py index ded6f181fc..63d0e1aff9 100644 --- a/gensim/test/test_similarities.py +++ b/gensim/test/test_similarities.py @@ -40,6 +40,9 @@ class _TestSimilarityABC(object): Base class for SparseMatrixSimilarity and MatrixSimilarity unit tests. """ + def factoryMethod(self): + return self.cls(corpus, num_features=len(dictionary)) + def testFull(self, num_best=None, shardsize=100): if self.cls == similarities.Similarity: index = self.cls(None, corpus, num_features=len(dictionary), shardsize=shardsize) @@ -104,6 +107,15 @@ def test_scipy2scipy_clipped(self): self.assertTrue(scipy.sparse.issparse(matrix_scipy_clipped)) self.assertTrue([matutils.scipy2sparse(x) for x in matrix_scipy_clipped], [expected] * 3) + def testEmptyQuery(self): + index = self.factoryMethod() + query = [] + try: + sims = index[query] + self.assertTrue(sims is not None) + except IndexError: + self.assertTrue(False) + def testChunking(self): if self.cls == similarities.Similarity: index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) @@ -156,14 +168,7 @@ def testPersistency(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() index.save(fname) index2 = self.cls.load(fname) if self.cls == similarities.Similarity: @@ -183,14 +188,7 @@ def testPersistencyCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() index.save(fname) index2 = self.cls.load(fname) if self.cls == similarities.Similarity: @@ -210,14 +208,7 @@ def testLarge(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -239,14 +230,7 @@ def testLargeCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -268,14 +252,7 @@ def testMmap(self): return fname = get_tmpfile('gensim_similarities.tst.pkl') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -298,14 +275,7 @@ def testMmapCompressed(self): return fname = get_tmpfile('gensim_similarities.tst.pkl.gz') - if self.cls == similarities.Similarity: - index = self.cls(None, corpus, num_features=len(dictionary), shardsize=5) - elif self.cls == similarities.WmdSimilarity: - index = self.cls(texts, self.w2v_model) - elif self.cls == similarities.SoftCosineSimilarity: - index = self.cls(self.corpus, self.similarity_matrix) - else: - index = self.cls(corpus, num_features=len(dictionary)) + index = self.factoryMethod() # store all arrays separately index.save(fname, sep_limit=0) @@ -323,6 +293,9 @@ def setUp(self): self.cls = similarities.WmdSimilarity self.w2v_model = Word2Vec(texts, min_count=1) + def factoryMethod(self): + return self.cls(texts, self.w2v_model) + def testFull(self, num_best=None): # Override testFull. @@ -405,6 +378,9 @@ def setUp(self): similarity_matrix[dictionary.token2id["human"], dictionary.token2id["user"]] = 0.5 self.similarity_matrix = similarity_matrix.tocsc() + def factoryMethod(self): + return self.cls(self.corpus, self.similarity_matrix) + def testFull(self, num_best=None): # Override testFull. @@ -504,6 +480,9 @@ class TestSimilarity(unittest.TestCase, _TestSimilarityABC): def setUp(self): self.cls = similarities.Similarity + def factoryMethod(self): + return self.cls(None, corpus, num_features=len(dictionary), shardsize=5) + def testSharding(self): for num_best in [None, 0, 1, 9, 1000]: for shardsize in [1, 2, 9, 1000]: