
Implement Soft Cosine Measure #1827

Merged
merged 40 commits on Feb 8, 2018
Changes from 4 commits

Commits (40)
c732666
Implement Soft Cosine Similarity
Witiko Jan 4, 2018
4c66255
Added numpy-style documentation for Soft Cosine Similarity
Witiko Jan 6, 2018
fd14514
Added unit tests for Soft Cosine Similarity
Witiko Jan 7, 2018
fc31b65
Make WmdSimilarity and SoftCosineSimilarity handle empty queries
Witiko Jan 7, 2018
81cb81b
Merge remote-tracking branch 'upstream/develop' into softcossim
Witiko Jan 28, 2018
40a2a34
Rename Soft Cosine Similarity to Soft Cosine Measure
Witiko Jan 28, 2018
50a7274
Add links to Soft Cosine Measure papers
Witiko Jan 28, 2018
08dea4e
Remove unused variables and parameters for Soft Cosine Measure
Witiko Jan 28, 2018
8af5f67
Replace explicit timers with magic %time in Soft Cosine Measure notebook
Witiko Jan 28, 2018
effef71
Rename var in term similarity matrix construction to reflect symmetry
Witiko Jan 28, 2018
3cf154e
Update SoftCosineSimilarity class example to define all variables
Witiko Jan 28, 2018
621ed0d
Make the code in Soft Cosine Measure notebook more compact
Witiko Jan 28, 2018
e1eb7cd
Use hanging indents in EuclideanKeyedVectors.similarity_matrix
Witiko Jan 28, 2018
03a967b
Simplified expressions in WmdSimilarity and SoftCosineSimilarity
Witiko Jan 28, 2018
5e3973e
Extract the sparse2coo function to the global scope
Witiko Jan 28, 2018
9782782
Fix __str__ of SoftCosineSimilarity
Witiko Jan 28, 2018
c7f0ce1
Use hanging indents in SoftCossim.__init__
Witiko Jan 28, 2018
6cfa472
Fix formatting of the matutils module
Witiko Jan 29, 2018
bf36770
Make similarity matrix info messages appear at fixed frequency
Witiko Jan 31, 2018
710a399
Construct term similarity matrix rows for important terms first
Witiko Feb 1, 2018
0752712
Optimize softcossim for an estimated 100-fold constant speed increase
Witiko Feb 2, 2018
2d5869d
Merge remote-tracking branch 'upstream/develop' into HEAD
Witiko Feb 2, 2018
a7ccabd
Remove unused import in gensim.similarities.docsim
Witiko Feb 2, 2018
02e6a8a
Fix imports in gensim.models.keyedvectors
Witiko Feb 2, 2018
980874b
replace reference to anonymous link
menshikh-iv Feb 5, 2018
b6a635a
Update "See Also" references to new *2vec implementation
Witiko Feb 5, 2018
640795e
Fix formatting error in gensim.models.keyedvectors
Witiko Feb 5, 2018
c83b377
Update Soft Cosine Measure tutorial notebook
Witiko Feb 5, 2018
3148b60
Update Soft Cosine Measure tutorial notebook
Witiko Feb 5, 2018
61a11d4
Use smaller glove-wiki-gigaword-50 model in Soft Cosine Measure notebook
Witiko Feb 6, 2018
e4f80eb
Use gensim-data to load SemEval datasets in Soft Cosine Measure notebook
Witiko Feb 6, 2018
418ae08
Use backwards-compatible syntax in Soft Cosine Similarity notebook
Witiko Feb 6, 2018
86044d4
Remove unnecessary package requirements in Soft Cosine Measure notebook
Witiko Feb 6, 2018
3ca0f8c
Fix Soft Cosine Measure notebook to use true gensim-data dataset names
Witiko Feb 6, 2018
55215f5
fix docs[1]
menshikh-iv Feb 8, 2018
c564b58
fix docs[2]
menshikh-iv Feb 8, 2018
69ec746
fix docs[3]
menshikh-iv Feb 8, 2018
70484c1
small fixes
menshikh-iv Feb 8, 2018
d6afdc2
small fixes[2]
menshikh-iv Feb 8, 2018
326ee66
Merge remote-tracking branch 'origin/softcossim' into softcossim
menshikh-iv Feb 8, 2018
613 changes: 613 additions & 0 deletions docs/notebooks/soft_cosine_tutorial.ipynb

Large diffs are not rendered by default.

Binary file added docs/notebooks/soft_cosine_tutorial.png
76 changes: 76 additions & 0 deletions gensim/matutils.py
@@ -464,6 +464,82 @@ def cossim(vec1, vec2):
return result


def softcossim(vec1, vec2, similarity_matrix):
"""Return Soft Cosine Similarity between two vectors given a term similarity matrix.
Reviewer comment:
Documentation looks really nice (only small tips that I'll fix myself once everything is ready on your side) 👍 🔥

Return the Soft Cosine Similarity between two sparse vectors, given a sparse term similarity matrix
in the `scipy.sparse.csc_matrix` format. The similarity is a number within [-1.0, 1.0]; higher
means more similar.

Parameters
----------
vec1 : list of (int, float) two-tuples
A query vector in the gensim document format.
vec2 : list of (int, float) two-tuples
A document vector in the gensim document format.
similarity_matrix : scipy.sparse.csc_matrix
A term similarity matrix.

Returns
-------
similarity_matrix.dtype
The Soft Cosine Similarity between `vec1` and `vec2`.

Raises
------
ValueError
When the term similarity matrix is in an unknown format.

See Also
--------
gensim.models.keyedvectors.EuclideanKeyedVectors.similarity_matrix
A term similarity matrix produced from term embeddings.
gensim.similarities.docsim.SoftCosineSimilarity
A class for performing corpus-based similarity queries with Soft Cosine Similarity.

References
----------
Soft Cosine Similarity was perhaps first defined by [sidorovetal14]_.

.. [sidorovetal14] Grigori Sidorov et al., "Soft Similarity and Soft Cosine Measure: Similarity
of Features in Vector Space Model", 2014.
"""

def sparse2coo(vec):
Reviewer comment:
Is there no existing function for this in Gensim?

Reviewer comment:
Also, it would be better not to define this inside the function (it can be useful elsewhere); please define it outside softcossim.

Author reply (Witiko):
I extracted the sparse2coo function to the global scope in 5e3973e. The existing corpus2csc function could be (mis)used for this purpose, but this yielded a significant slowdown in my tests.
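For comparison, a minimal sketch of the corpus2csc route mentioned above; it builds the same (num_terms, 1) column vector, but was reportedly slower for single documents in the author's tests. The vocabulary size below is hypothetical.

import scipy.sparse
from gensim import matutils

num_terms = 4                   # stands in for similarity_matrix.shape[0] inside softcossim
vec = [(0, 2.0), (3, 1.0)]      # a single document in the gensim bag-of-words format

# corpus2csc expects a corpus (an iterable of documents), so wrap the single vector in a list.
column = matutils.corpus2csc([vec], num_terms=num_terms)
assert scipy.sparse.issparse(column) and column.shape == (num_terms, 1)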

col = [0] * len(vec)
row, data = zip(*vec)
return scipy.sparse.coo_matrix((data, (row, col)), shape=(similarity_matrix.shape[0], 1),
dtype=similarity_matrix.dtype)

def softdot(vec1, vec2):
vec1 = vec1.tocsr()
vec2 = vec2.tocsc()
return (vec1.T).dot(similarity_matrix).dot(vec2)[0, 0]
Reviewer comment:
What's the shape of (vec1.T).dot(similarity_matrix).dot(vec2)? (Slightly confused because of the [0, 0].)

Author reply (Witiko):
The shape is (1, 1), i.e. a single scalar value packed as a sparse matrix.
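A small self-contained sketch of the shapes involved, using a hypothetical two-term vocabulary:

import scipy.sparse

S = scipy.sparse.identity(2, format="csc")             # term similarity matrix, shape (2, 2)
v1 = scipy.sparse.coo_matrix([[1.0], [2.0]]).tocsr()   # column vector, shape (2, 1)
v2 = scipy.sparse.coo_matrix([[3.0], [0.0]]).tocsc()   # column vector, shape (2, 1)

product = (v1.T).dot(S).dot(v2)   # (1, 2) x (2, 2) x (2, 1) -> shape (1, 1)
print(product.shape)              # (1, 1)
print(product[0, 0])              # 3.0 -- the single scalar value, unpacked with [0, 0]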


if not isinstance(similarity_matrix, scipy.sparse.csc_matrix):
if isinstance(similarity_matrix, scipy.sparse.csr_matrix):
Reviewer comment:
What about other formats (sparse, but not csc or csr)?

Author reply (Witiko):
These would incur a performance hit, because we would first need to convert them to csc / csr. We can either issue a warning and try to convert anything remotely meaningful to the csc / csr format, which takes time linear in the number of nonzero entries in similarity_matrix and temporarily holds two copies of the (potentially huge) matrix, or we can just refuse anything that is in the wrong format. In the code, I chose the latter approach.

similarity_matrix = similarity_matrix.T
else:
raise ValueError('unknown similarity matrix format')

if not vec1 or not vec2:
return 0.0
vec1 = sparse2coo(vec1)
vec2 = sparse2coo(vec2)
vec1len = softdot(vec1, vec1)
vec2len = softdot(vec2, vec2)
assert vec1len > 0.0 and vec2len > 0.0, u"sparse documents must not contain any explicit zero" \
" entries and the similarity matrix S must satisfy x^T * S * x > 0 for any nonzero" \
" bag-of-words vector x."
result = softdot(vec1, vec2)
result /= math.sqrt(vec1len) * math.sqrt(vec2len) # rescale by vector lengths
if result > 1.0:
Reviewer comment:
Is this to avoid floating-point issues (the result can come out as 1.0000003 or something similar)?

Author reply (Witiko):
It is.

return 1.0
if result < -1.0:
return -1.0
return result
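A minimal usage sketch of the function added above, with a hand-built three-term similarity matrix in which terms 0 and 1 are moderately similar (the values are illustrative only):

import scipy.sparse
from gensim import matutils

similarity_matrix = scipy.sparse.csc_matrix([
    [1.0, 0.5, 0.0],
    [0.5, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
query = [(0, 1.0), (2, 1.0)]      # bag-of-words vectors in the gensim document format
document = [(1, 1.0), (2, 1.0)]

print(matutils.softcossim(query, document, similarity_matrix))  # 0.75; plain cosine would give 0.5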


def isbow(vec):
"""
Checks if vector passed is in bag of words representation or not.
87 changes: 86 additions & 1 deletion gensim/models/keyedvectors.py
@@ -57,6 +57,7 @@
from __future__ import division # py3 "true division"

import logging
from math import ceil

try:
from queue import Queue, Empty
@@ -81,7 +82,7 @@
from gensim.corpora.dictionary import Dictionary
from six import string_types, iteritems
from six.moves import xrange
from scipy import stats
from scipy import stats, sparse


logger = logging.getLogger(__name__)
@@ -559,6 +560,90 @@ def similar_by_vector(self, vector, topn=10, restrict_vocab=None):
"""
return self.most_similar(positive=[vector], topn=topn, restrict_vocab=restrict_vocab)

def similarity_matrix(self, corpus, dictionary, threshold=0.0, exponent=2.0,
Reviewer comment:
corpus isn't used.

Author reply (Witiko):
Fixed in 08dea4e.

nonzero_limit=100, dtype=REAL):
"""Constructs a term similarity matrix for computing Soft Cosine Similarity.

Constructs a sparse term similarity matrix in the `scipy.sparse.csc_matrix` format for
computing Soft Cosine Similarity between documents.

Parameters
----------
corpus : list of lists of (int, float) two-tuples
A list of documents in the gensim document format.
dictionary : gensim.corpora.Dictionary
A dictionary associated with the corpus.
threshold : float, optional
Only pairs of words whose embeddings are more similar than `threshold` are considered
when building the sparse term similarity matrix. Defaults to `0.0`.
exponent : float, optional
The exponent applied to the similarity between two word embeddings when building the
term similarity matrix. Defaults to `2.0`.
nonzero_limit : int, optional
The maximum number of non-zero elements outside the diagonal in a single row or column
of the term similarity matrix. Setting `nonzero_limit` to a constant ensures that the
time complexity of computing the Soft Cosine Similarity will be linear in the document
length rather than quadratic. Defaults to `100`.
dtype : numpy.dtype, optional
Data-type of the term similarity matrix. Defaults to `numpy.float32`.

Returns
-------
scipy.sparse.csc_matrix
The constructed term similarity matrix.

See Also
--------
gensim.matutils.softcossim
The Soft Cosine Similarity.
gensim.similarities.docsim.SoftCosineSimilarity
A class for performing corpus-based similarity queries with Soft Cosine Similarity.

References
----------
The constructed matrix corresponds to the matrix Mrel defined in section 2.1 of
[charletdamnati17]_.

.. [charletdamnati17] Delphine Charlet and Geraldine Damnati, "SimBow at SemEval-2017
Task 3: Soft-Cosine Semantic Similarity between Questions for Community Question
Answering", 2017.
"""

logger.info("constructing a term similarity matrix")
similarity_matrix = sparse.identity(len(dictionary), dtype=dtype, format="lil")
# Traverse rows.
num_rows = len(dictionary)
for w1_index in range(num_rows):
if w1_index % ceil(num_rows / 10) == 0:
logger.info("PROGRESS: at row %i / %i", w1_index + 1, num_rows)
w1 = dictionary[w1_index]
if w1 not in self.vocab:
continue # A word from the dictionary not present in the word2vec model.
# Traverse upper triangle columns.
if len(dictionary) <= nonzero_limit + 1: # Traverse all columns.
Reviewer comment:
num_rows instead of len(dictionary)?

Author reply (Witiko):
Fixed in effef71.

columns = ((w2_index, self.similarity(w1, dictionary[w2_index]))
for w2_index in range(w1_index + 1, num_rows)
if w1_index != w2_index and dictionary[w2_index] in self.vocab)
else: # Traverse only columns corresponding to the embeddings closest to w1.
num_nonzero = similarity_matrix[w1_index].getnnz() - 1
columns = ((dictionary.token2id[w2], similarity)
Reviewer comment:
Please use hanging indents (instead of vertical).

Author reply (Witiko):
Fixed in e1eb7cd and c7f0ce1.

for _, (w2, similarity)
in zip(range(nonzero_limit - num_nonzero),
self.most_similar(positive=[w1],
topn=nonzero_limit - num_nonzero))
if w2 in dictionary.token2id and w1_index < dictionary.token2id[w2])
columns = sorted(columns, key=lambda x: x[0])
for w2_index, similarity in columns:
assert w1_index < w2_index
# Ensure that we don't exceed `nonzero_limit` by mirroring the upper triangle.
if similarity > threshold and similarity_matrix[w2_index].getnnz() <= nonzero_limit:
element = similarity**exponent
similarity_matrix[w1_index, w2_index] = element
Reviewer comment:
similarity_matrix is symmetric; maybe it would be better to store only "half" of the matrix and halve the memory usage?

Author reply (Witiko, Jan 12, 2018):
I agree that such a saving would be nice, but there seems to be no support in SciPy for a dot product with a symmetric matrix that only stores the upper / lower triangle ((vec1.T).dot(similarity_matrix).dot(vec2)[0, 0]). Even beyond SciPy, I don't know of a sparse matrix format that would allow efficient access both row-wise and column-wise.

Reviewer comment:
Sad but true, thanks for the clarification.

Author reply (Witiko):
Note that this knowledge can still be useful when storing and transmitting the similarity matrix.
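A hedged sketch of how the symmetry could still be exploited when persisting the matrix, assuming SciPy's triu, diags and save_npz/load_npz helpers are available; the full symmetric matrix is rebuilt before it is handed to softcossim:

import scipy.sparse

# Toy stand-in for the term similarity matrix built by this method.
similarity_matrix = scipy.sparse.csc_matrix([[1.0, 0.4], [0.4, 1.0]])

# Persist only the upper triangle (including the diagonal), roughly halving the stored nonzeros.
upper = scipy.sparse.triu(similarity_matrix)
scipy.sparse.save_npz("similarity_upper.npz", upper)

# Rebuild the full symmetric matrix before querying.
upper = scipy.sparse.load_npz("similarity_upper.npz")
diagonal = scipy.sparse.diags(upper.diagonal())
restored = (upper + upper.T - diagonal).tocsc()
assert (restored != similarity_matrix).nnz == 0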

similarity_matrix[w2_index, w1_index] = element
logger.info("constructed a term similarity matrix with %0.2f %% nonzero entries",
100.0 * similarity_matrix.getnnz() / num_rows**2)
return similarity_matrix.tocsc()
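A minimal sketch of calling the new method, mirroring the unit test added in this PR; it assumes the euclidean_vectors.bin fixture shipped with gensim's test suite and uses the signature shown in this diff (the unused corpus argument was removed in a later commit):

import numpy as np
from gensim.corpora import Dictionary
from gensim.models.keyedvectors import EuclideanKeyedVectors
from gensim.test.utils import datapath

vectors = EuclideanKeyedVectors.load_word2vec_format(
    datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)
documents = [["government", "denied", "holiday"], ["holiday", "slowing", "hollingworth"]]
dictionary = Dictionary(documents)
corpus = [dictionary.doc2bow(document) for document in documents]

matrix = vectors.similarity_matrix(corpus, dictionary, threshold=0.0, exponent=2.0, nonzero_limit=100)
print(matrix.format)                   # 'csc'
print((matrix != matrix.T).nnz == 0)   # True -- the matrix is symmetric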

def wmdistance(self, document1, document2):
"""
Compute the Word Mover's Distance between two documents. When using this
2 changes: 1 addition & 1 deletion gensim/similarities/__init__.py
@@ -3,4 +3,4 @@
"""

# bring classes directly into package namespace, to save some typing
from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, WmdSimilarity # noqa:F401
from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, SoftCosineSimilarity, WmdSimilarity # noqa:F401
104 changes: 103 additions & 1 deletion gensim/similarities/docsim.py
@@ -54,6 +54,7 @@
import itertools
import os
import heapq
import warnings

import numpy
import scipy.sparse
@@ -563,6 +564,107 @@ def __str__(self):
return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.index.shape[1])


class SoftCosineSimilarity(interfaces.SimilarityABC):
"""Document similarity (like MatrixSimilarity) that uses Soft Cosine Similarity as a similarity
measure.

Parameters
----------
corpus: list of lists of (int, float) two-tuples
A list of documents in the gensim document format.
similarity_matrix : scipy.sparse.csc_matrix
A term similarity matrix.
num_best : int
The number of results to retrieve for a query.

See Also
--------
gensim.models.keyedvectors.EuclideanKeyedVectors.similarity_matrix
A term similarity matrix produced from term embeddings.
gensim.matutils.softcossim
The Soft Cosine Similarity.

Examples
--------
>>> from gensim.models import Word2Vec
>>> from gensim.similarities import SoftCosineSimilarity
>>> from gensim.utils import simple_preprocess
>>> from gensim.corpora import Dictionary
>>> # Given a document collection "corpus", train a word2vec model.
>>> model = Word2Vec(corpus, workers=3, size=100)
Reviewer comment:
All variables should be defined (this is really important for examples). You can take a small corpus from gensim.test.utils or use, for example, this one:

import gensim.downloader as api
corpus = api.load("text8")

Author reply (Witiko):
Fixed in 3cf154e.

>>> # Construct a bag-of-words corpus, a dictionary, and a term similarity matrix.
>>> dictionary = Dictionary(corpus)
>>> corpus = [dictionary.doc2bow(document) for document in corpus]
>>> similarity_matrix = model.wv.similarity_matrix(corpus, dictionary)
>>> index = SoftCosineSimilarity(corpus, similarity_matrix, num_best=10)
>>> # Make a query.
>>> query = 'Yummy! Great view of the Bellagio Fountain show.'
>>> sims = index[dictionary.doc2bow(simple_preprocess(query))]

See `Tutorial Notebook
<https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/soft_cosine_tutorial.ipynb>`_
for more examples.
"""

def __init__(self, corpus, similarity_matrix, num_best=None, chunksize=256):
self.corpus = corpus
self.num_best = num_best
self.chunksize = chunksize

# Normalization of features is undesirable, since soft cosine similarity requires special
# normalization using the similarity matrix. Therefore, we would just be normalizing twice,
# increasing the numerical error.
self.normalize = False

# index is simply an array from 0 to size of corpus.
self.index = numpy.array(range(len(corpus)))
Reviewer comment:
numpy.arange(len(corpus))

Reviewer comment:
Also, should corpus always have __len__? For example, what if corpus is something like this:

class MyCorpus(object):
    def __init__(self, fname):
        self.fname = fname
    def __iter__(self):
        with smart_open(self.fname) as infile:
            for line in infile:
                yield line

Author reply (Witiko, Jan 12, 2018):
This was taken straight from the WmdSimilarity class. Note that we also assume that a corpus has __getitem__ (https://github.com/Witiko/gensim/blob/fc31b659f0d635c827e505b6accb17fd3f8b1925/gensim/similarities/docsim.py#L640) in both WmdSimilarity and SoftCosineSimilarity.

Author reply (Witiko):
The first concern has been fixed in 03a967b.
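A short sketch of the practical consequence: a purely streamed corpus like the reviewer's MyCorpus above has neither __len__ nor __getitem__, so it should be materialized (here into a plain list) before indexing. The file name is hypothetical, and dictionary and similarity_matrix are assumed to exist as in the class example.

from gensim.similarities import SoftCosineSimilarity

class MyBowCorpus(object):
    """Streams bag-of-words vectors from a text file; for illustration only."""
    def __init__(self, fname, dictionary):
        self.fname = fname
        self.dictionary = dictionary
    def __iter__(self):
        with open(self.fname) as infile:
            for line in infile:
                yield self.dictionary.doc2bow(line.split())

corpus = list(MyBowCorpus('documents.txt', dictionary))  # len(corpus) and corpus[i] now both work
index = SoftCosineSimilarity(corpus, similarity_matrix, num_best=10)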


# Remove the columns of the similarity matrix that correspond to terms outside corpus.
nonzero_columns = sorted(set((index for document in corpus for index, _ in document)))
identity_matrix = scipy.sparse.identity(similarity_matrix.shape[0],
dtype=similarity_matrix.dtype, format="csr")
with warnings.catch_warnings():
Reviewer comment:
Why is this block needed?

Author reply (Witiko, Jan 12, 2018):
As stated in the comment, the intent is to "remove the columns of the similarity matrix that correspond to terms outside the corpus". This creates a smaller matrix with zeroed-out columns for terms that never appear in the corpus and will therefore never be accessed in a dot product. Due to deficiencies in the SciPy dot product implementation, the subsequent calls to gensim.matutils.softcossim will also be faster with such a matrix; an optimized dot product implementation would not access these columns either way.

warnings.simplefilter("ignore", scipy.sparse.SparseEfficiencyWarning)
identity_matrix[nonzero_columns] = similarity_matrix.T[nonzero_columns]
self.similarity_matrix = identity_matrix.T
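A toy illustration of the masking performed above, with hypothetical values: term 1 does not occur in the corpus, so the stored matrix keeps only a unit column for it, while columns 0 and 2 are copied from the original matrix.

import warnings
import scipy.sparse

similarity_matrix = scipy.sparse.csc_matrix([
    [1.0, 0.3, 0.2],
    [0.3, 1.0, 0.4],
    [0.2, 0.4, 1.0],
])
nonzero_columns = [0, 2]  # term ids that actually occur in the corpus

identity_matrix = scipy.sparse.identity(3, dtype=similarity_matrix.dtype, format="csr")
with warnings.catch_warnings():
    warnings.simplefilter("ignore", scipy.sparse.SparseEfficiencyWarning)
    identity_matrix[nonzero_columns] = similarity_matrix.T[nonzero_columns]
masked = identity_matrix.T
print(masked.todense())  # column 1 is now the unit vector; columns 0 and 2 are unchanged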

def __len__(self):
return len(self.corpus)

def get_similarities(self, query):
"""
**Do not use this function directly; use the self[query] syntax instead.**
"""
if isinstance(query, numpy.ndarray):
# Convert document indexes to actual documents.
query = [self.corpus[i] for i in query]

if not query or not isinstance(query[0], list):
query = [query]

n_queries = len(query)
result = []
for qidx in range(n_queries):
# Compute similarity for each query.
qresult = [matutils.softcossim(document, query[qidx], self.similarity_matrix)
for document in self.corpus]
qresult = numpy.array(qresult)

# Append single query result to list of all results.
result.append(qresult)

if len(result) == 1:
# Only one query.
result = result[0]
else:
result = numpy.array(result)

return result

def __str__(self):
return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.w2v_model.wv.syn0.shape[1])


class WmdSimilarity(interfaces.SimilarityABC):
"""
Document similarity (like MatrixSimilarity) that uses the negative of WMD
@@ -622,7 +724,7 @@ def get_similarities(self, query):
# Convert document indexes to actual documents.
query = [self.corpus[i] for i in query]

if not isinstance(query[0], list):
if not query or not isinstance(query[0], list):
query = [query]

n_queries = len(query)
32 changes: 32 additions & 0 deletions gensim/test/test_keyedvectors.py
@@ -14,6 +14,7 @@

import numpy as np

from gensim.corpora import Dictionary
from gensim.models.keyedvectors import EuclideanKeyedVectors
from gensim.test.utils import datapath

@@ -26,6 +27,37 @@ def setUp(self):
self.vectors = EuclideanKeyedVectors.load_word2vec_format(
datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)

def test_similarity_matrix(self):
"""Test similarity_matrix returns expected results."""

corpus = [["government", "denied", "holiday"], ["holiday", "slowing", "hollingworth"]]
dictionary = Dictionary(corpus)
corpus = [dictionary.doc2bow(document) for document in corpus]

# checking symmetry and the existence of ones on the diagonal
similarity_matrix = self.vectors.similarity_matrix(corpus, dictionary).todense()
self.assertTrue((similarity_matrix.T == similarity_matrix).all())
self.assertTrue((np.diag(similarity_matrix) == similarity_matrix).all())

# checking that thresholding works as expected
similarity_matrix = self.vectors.similarity_matrix(corpus, dictionary, threshold=0.45).todense()
expected = 18
self.assertEquals(expected, np.sum(similarity_matrix == 0))

# checking that exponent works as expected
similarity_matrix = self.vectors.similarity_matrix(corpus, dictionary, exponent=1.0).todense()
expected = 9.5788956
self.assertAlmostEqual(expected, np.sum(similarity_matrix))

# checking that nonzero_limit works as expected
similarity_matrix = self.vectors.similarity_matrix(corpus, dictionary, nonzero_limit=4).todense()
expected = 4
self.assertEquals(expected, np.sum(similarity_matrix == 0))

similarity_matrix = self.vectors.similarity_matrix(corpus, dictionary, nonzero_limit=3).todense()
expected = 20
self.assertEquals(expected, np.sum(similarity_matrix == 0))

def test_most_similar(self):
"""Test most_similar returns expected results."""
expected = [