Implement Soft Cosine Measure #1827

Merged
merged 40 commits into from
Feb 8, 2018
40 commits
c732666
Implement Soft Cosine Similarity
Witiko Jan 4, 2018
4c66255
Added numpy-style documentation for Soft Cosine Similarity
Witiko Jan 6, 2018
fd14514
Added unit tests for Soft Cosine Similarity
Witiko Jan 7, 2018
fc31b65
Make WmdSimilarity and SoftCosineSimilarity handle empty queries
Witiko Jan 7, 2018
81cb81b
Merge remote-tracking branch 'upstream/develop' into softcossim
Witiko Jan 28, 2018
40a2a34
Rename Soft Cosine Similarity to Soft Cosine Measure
Witiko Jan 28, 2018
50a7274
Add links to Soft Cosine Measure papers
Witiko Jan 28, 2018
08dea4e
Remove unused variables and parameters for Soft Cosine Measure
Witiko Jan 28, 2018
8af5f67
Replace explicit timers with magic %time in Soft Cosine Measure notebook
Witiko Jan 28, 2018
effef71
Rename var in term similarity matrix construction to reflect symmetry
Witiko Jan 28, 2018
3cf154e
Update SoftCosineSimilarity class example to define all variables
Witiko Jan 28, 2018
621ed0d
Make the code in Soft Cosine Measure notebook more compact
Witiko Jan 28, 2018
e1eb7cd
Use hanging indents in EuclideanKeyedVectors.similarity_matrix
Witiko Jan 28, 2018
03a967b
Simplified expressions in WmdSimilarity and SoftCosineSimilarity
Witiko Jan 28, 2018
5e3973e
Extract the sparse2coo function to the global scope
Witiko Jan 28, 2018
9782782
Fix __str__ of SoftCosineSimilarity
Witiko Jan 28, 2018
c7f0ce1
Use hanging indents in SoftCossim.__init__
Witiko Jan 28, 2018
6cfa472
Fix formatting of the matutils module
Witiko Jan 29, 2018
bf36770
Make similarity matrix info messages appear at fixed frequency
Witiko Jan 31, 2018
710a399
Construct term similarity matrix rows for important terms first
Witiko Feb 1, 2018
0752712
Optimize softcossim for an estimated 100-fold constant speed increase
Witiko Feb 2, 2018
2d5869d
Merge remote-tracking branch 'upstream/develop' into HEAD
Witiko Feb 2, 2018
a7ccabd
Remove unused import in gensim.similarities.docsim
Witiko Feb 2, 2018
02e6a8a
Fix imports in gensim.models.keyedvectors
Witiko Feb 2, 2018
980874b
replace reference to anonymous link
menshikh-iv Feb 5, 2018
b6a635a
Update "See Also" references to new *2vec implementation
Witiko Feb 5, 2018
640795e
Fix formatting error in gensim.models.keyedvectors
Witiko Feb 5, 2018
c83b377
Update Soft Cosine Measure tutorial notebook
Witiko Feb 5, 2018
3148b60
Update Soft Cosine Measure tutorial notebook
Witiko Feb 5, 2018
61a11d4
Use smaller glove-wiki-gigaword-50 model in Soft Cosine Measure notebook
Witiko Feb 6, 2018
e4f80eb
Use gensim-data to load SemEval datasets in Soft Cosine Measure notebook
Witiko Feb 6, 2018
418ae08
Use backwards-compatible syntax in Soft Cosine Similarity notebook
Witiko Feb 6, 2018
86044d4
Remove unnecessary package requirements in Soft Cosine Measure notebook
Witiko Feb 6, 2018
3ca0f8c
Fix Soft Cosine Measure notebook to use true gensim-data dataset names
Witiko Feb 6, 2018
55215f5
fix docs[1]
menshikh-iv Feb 8, 2018
c564b58
fix docs[2]
menshikh-iv Feb 8, 2018
69ec746
fix docs[3]
menshikh-iv Feb 8, 2018
70484c1
small fixes
menshikh-iv Feb 8, 2018
d6afdc2
small fixes[2]
menshikh-iv Feb 8, 2018
326ee66
Merge remote-tracking branch 'origin/softcossim' into softcossim
menshikh-iv Feb 8, 2018
593 changes: 593 additions & 0 deletions docs/notebooks/soft_cosine_tutorial.ipynb

Large diffs are not rendered by default.

Binary file added docs/notebooks/soft_cosine_tutorial.png
72 changes: 72 additions & 0 deletions gensim/matutils.py
@@ -9,6 +9,7 @@
from __future__ import with_statement


from itertools import chain
import logging
import math

@@ -755,6 +756,77 @@ def cossim(vec1, vec2):
return result


def softcossim(vec1, vec2, similarity_matrix):
    """Get Soft Cosine Measure between two vectors given a term similarity matrix.

    Return Soft Cosine Measure between two sparse vectors given a sparse term similarity matrix
    in the :class:`scipy.sparse.csc_matrix` format. The similarity is a number in the interval
    [-1.0, 1.0]; higher is more similar.

    Parameters
    ----------
    vec1 : list of (int, float)
        A query vector in the BoW format.
    vec2 : list of (int, float)
        A document vector in the BoW format.
    similarity_matrix : {:class:`scipy.sparse.csc_matrix`, :class:`scipy.sparse.csr_matrix`}
        A term similarity matrix, typically produced by
        :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`.

    Returns
    -------
    `similarity_matrix.dtype`
        The Soft Cosine Measure between `vec1` and `vec2`.

    Raises
    ------
    ValueError
        When the term similarity matrix is in an unknown format.

    See Also
    --------
    :meth:`gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`
        A term similarity matrix produced from term embeddings.
    :class:`gensim.similarities.docsim.SoftCosineSimilarity`
        A class for performing corpus-based similarity queries with Soft Cosine Measure.

    References
    ----------
    Soft Cosine Measure was perhaps first defined by [sidorovetal14]_.

    .. [sidorovetal14] Grigori Sidorov et al., "Soft Similarity and Soft Cosine Measure: Similarity
       of Features in Vector Space Model", 2014, http://www.cys.cic.ipn.mx/ojs/index.php/CyS/article/view/2043/1921.

    """
    if not isinstance(similarity_matrix, scipy.sparse.csc_matrix):
        if isinstance(similarity_matrix, scipy.sparse.csr_matrix):
            similarity_matrix = similarity_matrix.T
        else:
            raise ValueError('unknown similarity matrix format')

    if not vec1 or not vec2:
        return 0.0

    vec1 = dict(vec1)
    vec2 = dict(vec2)
    word_indices = sorted(set(chain(vec1, vec2)))
    dtype = similarity_matrix.dtype
    vec1 = np.array([vec1[i] if i in vec1 else 0 for i in word_indices], dtype=dtype)
    vec2 = np.array([vec2[i] if i in vec2 else 0 for i in word_indices], dtype=dtype)
    dense_matrix = similarity_matrix[[[i] for i in word_indices], word_indices].todense()
    vec1len = vec1.T.dot(dense_matrix).dot(vec1)[0, 0]
    vec2len = vec2.T.dot(dense_matrix).dot(vec2)[0, 0]

    assert \
        vec1len > 0.0 and vec2len > 0.0, \
        u"sparse documents must not contain any explicit zero entries and the similarity matrix S " \
        u"must satisfy x^T * S * x > 0 for any nonzero bag-of-words vector x."

    result = vec1.T.dot(dense_matrix).dot(vec2)[0, 0]
    result /= math.sqrt(vec1len) * math.sqrt(vec2len)  # rescale by vector lengths
    return np.clip(result, -1.0, 1.0)
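
For reference, the quantity that `softcossim` computes is x^T S y / (sqrt(x^T S x) * sqrt(y^T S y)), where S is the term similarity matrix. The following is a minimal pure-Python sketch of that formula, not the code from this pull request: the helper name `soft_cosine` and the dict-of-pairs representation of S are assumptions made so that the example runs without NumPy or SciPy.

```python
import math


def soft_cosine(vec1, vec2, sim):
    """Soft cosine of two BoW vectors; `sim` maps (i, j) term-id pairs to similarities.

    Hypothetical helper for illustration only. Missing pairs count as 0.0,
    and every term is assumed to have similarity 1.0 with itself.
    """
    if not vec1 or not vec2:
        return 0.0  # mirrors the empty-vector shortcut in softcossim

    def s(i, j):
        if i == j:
            return 1.0
        return sim.get((i, j), sim.get((j, i), 0.0))

    def bilinear(a, b):
        # Computes a^T * S * b over the nonzero entries only.
        return sum(wa * s(i, j) * wb for i, wa in a for j, wb in b)

    len1 = bilinear(vec1, vec1)
    len2 = bilinear(vec2, vec2)
    result = bilinear(vec1, vec2) / (math.sqrt(len1) * math.sqrt(len2))
    return max(-1.0, min(1.0, result))  # clip rounding error to [-1, 1]


# Terms 0 and 1 are related (similarity 0.5); term 2 is unrelated to both.
sim = {(0, 1): 0.5}
doc_a = [(0, 1.0)]
doc_b = [(1, 1.0)]
doc_c = [(2, 1.0)]

print(soft_cosine(doc_a, doc_b, sim))  # 0.5, whereas the plain cosine would be 0.0
print(soft_cosine(doc_a, doc_c, sim))  # 0.0
print(soft_cosine(doc_a, doc_a, sim))  # 1.0
```

The double loop in `bilinear` is quadratic in the number of nonzero terms, which is why the pull request bounds the number of nonzero entries per row of the similarity matrix.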


def isbow(vec):
"""Checks if vector passed is in BoW format.

113 changes: 110 additions & 3 deletions gensim/models/keyedvectors.py
@@ -80,7 +80,7 @@
from gensim.corpora.dictionary import Dictionary
from six import string_types, integer_types
from six.moves import xrange, zip
from scipy import stats
from scipy import sparse, stats
from gensim.utils import deprecated
from gensim.models.utils_any2vec import _save_word2vec_format, _load_word2vec_format, _compute_ngrams

@@ -191,8 +191,8 @@ def rank(self, entity1, entity2):


class WordEmbeddingsKeyedVectors(BaseKeyedVectors):
"""Class containing common methods for operations over word vectors.
"""
"""Class containing common methods for operations over word vectors."""

def __init__(self, vector_size):
super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size)
self.vectors_norm = None
@@ -432,6 +432,113 @@ def similar_by_vector(self, vector, topn=10, restrict_vocab=None):
"""
return self.most_similar(positive=[vector], topn=topn, restrict_vocab=restrict_vocab)

    def similarity_matrix(self, dictionary, tfidf=None, threshold=0.0, exponent=2.0, nonzero_limit=100, dtype=REAL):
        """Constructs a term similarity matrix for computing Soft Cosine Measure.

        Constructs a sparse term similarity matrix in the :class:`scipy.sparse.csc_matrix` format for computing
        Soft Cosine Measure between documents.

        Parameters
        ----------
        dictionary : :class:`~gensim.corpora.dictionary.Dictionary`
            A dictionary that specifies a mapping between words and the indices of rows and columns
            of the resulting term similarity matrix.
        tfidf : :class:`gensim.models.tfidfmodel.TfidfModel`, optional
            A model that specifies the relative importance of the terms in the dictionary. The rows
            of the term similarity matrix will be built in a decreasing order of importance of terms,
            or in the order of term identifiers if None.
        threshold : float, optional
            Only pairs of words whose embeddings are more similar than `threshold` are considered
            when building the sparse term similarity matrix.
        exponent : float, optional
            The exponent applied to the similarity between two word embeddings when building the term similarity matrix.
        nonzero_limit : int, optional
            The maximum number of non-zero elements outside the diagonal in a single row or column
            of the term similarity matrix. Setting `nonzero_limit` to a constant ensures that the
            time complexity of computing the Soft Cosine Measure will be linear in the document
            length rather than quadratic.
        dtype : numpy.dtype, optional
            Data-type of the term similarity matrix.

        Returns
        -------
        :class:`scipy.sparse.csc_matrix`
            Term similarity matrix.

        See Also
        --------
        :func:`gensim.matutils.softcossim`
            The Soft Cosine Measure.
        :class:`gensim.similarities.docsim.SoftCosineSimilarity`
            A class for performing corpus-based similarity queries with Soft Cosine Measure.

        Notes
        -----
        The constructed matrix corresponds to the matrix Mrel defined in section 2.1 of
        `Delphine Charlet and Geraldine Damnati, "SimBow at SemEval-2017 Task 3: Soft-Cosine Semantic Similarity
        between Questions for Community Question Answering", 2017
        <http://www.aclweb.org/anthology/S/S17/S17-2051.pdf>`__.

        """
        logger.info("constructing a term similarity matrix")
        matrix_order = len(dictionary)
        matrix_nonzero = [1] * matrix_order
        matrix = sparse.identity(matrix_order, dtype=dtype, format="dok")
        num_skipped = 0
        # Decide the order of rows.
        if tfidf is None:
            word_indices = range(matrix_order)
        else:
            assert max(tfidf.idfs) < matrix_order
            word_indices = [
                index for index, _ in sorted(tfidf.idfs.items(), key=lambda x: x[1], reverse=True)
            ]

        # Traverse rows.
        for row_number, w1_index in enumerate(word_indices):
            if row_number % 1000 == 0:
                logger.info(
                    "PROGRESS: at %.02f%% rows (%d / %d, %d skipped, %.06f%% density)",
                    100.0 * (row_number + 1) / matrix_order, row_number + 1, matrix_order,
                    num_skipped, 100.0 * matrix.getnnz() / matrix_order**2)
            w1 = dictionary[w1_index]
            if w1 not in self.vocab:
                num_skipped += 1
                continue  # A word from the dictionary is not present in the word2vec model.
            # Traverse upper triangle columns.
            if matrix_order <= nonzero_limit + 1:  # Traverse all columns.
                columns = (
                    (w2_index, self.similarity(w1, dictionary[w2_index]))
                    for w2_index in range(w1_index + 1, matrix_order)
                    if w1_index != w2_index and dictionary[w2_index] in self.vocab)
            else:  # Traverse only columns corresponding to the embeddings closest to w1.
                num_nonzero = matrix_nonzero[w1_index] - 1
                columns = (
                    (dictionary.token2id[w2], similarity)
                    for _, (w2, similarity)
                    in zip(
                        range(nonzero_limit - num_nonzero),
                        self.most_similar(positive=[w1], topn=nonzero_limit - num_nonzero)
                    )
                    if w2 in dictionary.token2id
                )
                columns = sorted(columns, key=lambda x: x[0])

            for w2_index, similarity in columns:
                # Ensure that we don't exceed `nonzero_limit` by mirroring the upper triangle.
                if similarity > threshold and matrix_nonzero[w2_index] <= nonzero_limit:
                    element = similarity**exponent
                    matrix[w1_index, w2_index] = element
                    matrix_nonzero[w1_index] += 1
                    matrix[w2_index, w1_index] = element
                    matrix_nonzero[w2_index] += 1
        logger.info(
            "constructed a term similarity matrix with %0.6f %% nonzero elements",
            100.0 * matrix.getnnz() / matrix_order**2
        )
        return matrix.tocsc()
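
The interplay of `threshold`, `exponent` and `nonzero_limit` in the loop above may be easier to follow on a toy input. The sketch below is a simplified pure-Python stand-in, not the method from this pull request: `build_term_similarity`, the `cosine` callback and the dict-of-keys matrix are assumptions for illustration, and unlike the original it skips pairs that were already written by mirroring.

```python
def build_term_similarity(n_terms, cosine, row_order=None,
                          threshold=0.0, exponent=2.0, nonzero_limit=100):
    """Build a sparse symmetric term similarity matrix as a dict {(i, j): value}.

    Hypothetical, simplified illustration: `cosine(i, j)` plays the role of
    the word-embedding similarity between terms i and j.
    """
    matrix = {(i, i): 1.0 for i in range(n_terms)}  # ones on the diagonal
    nonzero = [1] * n_terms  # nonzero count per row, diagonal included
    for i in (row_order if row_order is not None else range(n_terms)):
        budget = nonzero_limit - (nonzero[i] - 1)  # off-diagonal slots left in row i
        # Closest terms first; keep at most `budget` candidates.
        neighbours = sorted(
            ((cosine(i, j), j) for j in range(n_terms) if j != i),
            reverse=True)[:max(budget, 0)]
        for similarity, j in neighbours:
            if (i, j) in matrix:
                continue  # already filled in by mirroring from row j
            # The mirrored write must not push row j over the limit either.
            if similarity > threshold and nonzero[j] <= nonzero_limit:
                matrix[i, j] = matrix[j, i] = similarity ** exponent
                nonzero[i] += 1
                nonzero[j] += 1
    return matrix


# Toy cosine similarities between four terms (upper triangle only).
table = {(0, 1): 0.9, (0, 2): 0.6, (0, 3): 0.2,
         (1, 2): 0.5, (1, 3): 0.1, (2, 3): 0.8}


def cosine(i, j):
    return table[min(i, j), max(i, j)]


m = build_term_similarity(4, cosine, threshold=0.3, exponent=2.0, nonzero_limit=1)
for key in sorted(m):
    print(key, m[key])
```

With `nonzero_limit=1`, each row keeps only its single closest neighbour, so terms 0 and 1 are paired, terms 2 and 3 are paired, and the pair (0, 2) is dropped even though its similarity exceeds the threshold. Passing a `row_order` sorted by decreasing IDF would let the most important terms claim their neighbours first.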

def wmdistance(self, document1, document2):
"""
Compute the Word Mover's Distance between two documents. When using this
2 changes: 1 addition & 1 deletion gensim/similarities/__init__.py
@@ -3,4 +3,4 @@
"""

# bring classes directly into package namespace, to save some typing
from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, WmdSimilarity # noqa:F401
from .docsim import Similarity, MatrixSimilarity, SparseMatrixSimilarity, SoftCosineSimilarity, WmdSimilarity # noqa:F401
106 changes: 104 additions & 2 deletions gensim/similarities/docsim.py
@@ -563,6 +563,108 @@ def __str__(self):
return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.index.shape[1])


class SoftCosineSimilarity(interfaces.SimilarityABC):
    """Document similarity (like MatrixSimilarity) that uses Soft Cosine Measure as a similarity measure."""

    def __init__(self, corpus, similarity_matrix, num_best=None, chunksize=256):
        """

        Parameters
        ----------
        corpus : iterable of list of (int, float)
            A list of documents in the BoW format.
        similarity_matrix : :class:`scipy.sparse.csc_matrix`
            A term similarity matrix, typically produced by
            :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`.
        num_best : int, optional
            The number of results to retrieve for a query. If None, return similarities with all elements from corpus.
        chunksize : int, optional
            Size of one corpus chunk.

        See Also
        --------
        :meth:`gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`
            A term similarity matrix produced from term embeddings.
        :func:`gensim.matutils.softcossim`
            The Soft Cosine Measure.

        Examples
        --------
        >>> from gensim.corpora import Dictionary
        >>> import gensim.downloader as api
        >>> from gensim.models import Word2Vec
        >>> from gensim.similarities import SoftCosineSimilarity
        >>> from gensim.utils import simple_preprocess
        >>>
        >>> # Prepare the model
        >>> corpus = api.load("text8")
        >>> model = Word2Vec(corpus, workers=3, size=100)
        >>> dictionary = Dictionary(corpus)
        >>> bow_corpus = [dictionary.doc2bow(document) for document in corpus]
        >>> similarity_matrix = model.wv.similarity_matrix(dictionary)
        >>> index = SoftCosineSimilarity(bow_corpus, similarity_matrix, num_best=10)
        >>>
        >>> # Make a query.
        >>> query = 'Yummy! Great view of the Bellagio Fountain show.'
        >>> # calculate similarity between query and each doc from bow_corpus
        >>> sims = index[dictionary.doc2bow(simple_preprocess(query))]

        See `Tutorial Notebook
        <https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/soft_cosine_tutorial.ipynb>`_
        for more examples.

        """
        self.corpus = corpus
        self.similarity_matrix = similarity_matrix
        self.num_best = num_best
        self.chunksize = chunksize

        # Normalization of features is undesirable, since soft cosine similarity requires special
        # normalization using the similarity matrix. Therefore, we would just be normalizing twice,
        # increasing the numerical error.
        self.normalize = False

        # index is simply an array from 0 to size of corpus.
        self.index = numpy.arange(len(corpus))

    def __len__(self):
        return len(self.corpus)

    def get_similarities(self, query):
        """
        **Do not use this function directly; use the self[query] syntax instead.**
        """
        if isinstance(query, numpy.ndarray):
            # Convert document indexes to actual documents.
            query = [self.corpus[i] for i in query]

        if not query or not isinstance(query[0], list):
            query = [query]

        n_queries = len(query)
        result = []
        for qidx in range(n_queries):
            # Compute similarity for each query.
            qresult = [matutils.softcossim(document, query[qidx], self.similarity_matrix)
                       for document in self.corpus]
            qresult = numpy.array(qresult)

            # Append single query result to list of all results.
            result.append(qresult)

        if len(result) == 1:
            # Only one query.
            result = result[0]
        else:
            result = numpy.array(result)

        return result

    def __str__(self):
        return "%s<%i docs, %i features>" % (self.__class__.__name__, len(self), self.similarity_matrix.shape[0])
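
`get_similarities` accepts a single BoW document, a batch of BoW documents, or an empty query, and the `not query` guard is what keeps `query[0]` from raising `IndexError` on an empty list. A sketch of just that normalisation step follows; the helper name `as_query_batch` is hypothetical and does not appear in gensim.

```python
def as_query_batch(query):
    """Wrap a single BoW document into a one-element batch (hypothetical helper).

    A BoW document is a list of (term_id, weight) tuples, so its first element
    is a tuple; a batch of documents has a list as its first element.
    """
    if not query or not isinstance(query[0], list):
        return [query]
    return query


single = [(0, 1.0), (3, 2.0)]
batch = [[(0, 1.0)], [(3, 2.0)]]

print(as_query_batch(single))  # [[(0, 1.0), (3, 2.0)]]
print(as_query_batch(batch))   # [[(0, 1.0)], [(3, 2.0)]]
print(as_query_batch([]))      # [[]]
```

An empty query therefore becomes a batch holding one empty document, for which `softcossim` returns 0.0 against every document in the corpus.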


class WmdSimilarity(interfaces.SimilarityABC):
"""
Document similarity (like MatrixSimilarity) that uses the negative of WMD
@@ -605,7 +707,7 @@ def __init__(self, corpus, w2v_model, num_best=None, normalize_w2v_and_replace=T
self.normalize = False

# index is simply an array from 0 to size of corpus.
self.index = numpy.array(range(len(corpus)))
self.index = numpy.arange(len(corpus))

if normalize_w2v_and_replace:
# Normalize vectors in word2vec class to length 1.
@@ -622,7 +724,7 @@ def get_similarities(self, query):
# Convert document indexes to actual documents.
query = [self.corpus[i] for i in query]

if not isinstance(query[0], list):
if not query or not isinstance(query[0], list):
query = [query]

n_queries = len(query)
28 changes: 28 additions & 0 deletions gensim/test/test_keyedvectors.py
@@ -14,6 +14,7 @@

import numpy as np

from gensim.corpora import Dictionary
from gensim.models import KeyedVectors as EuclideanKeyedVectors
from gensim.test.utils import datapath

@@ -26,6 +27,33 @@ def setUp(self):
self.vectors = EuclideanKeyedVectors.load_word2vec_format(
datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)

    def test_similarity_matrix(self):
        """Test similarity_matrix returns expected results."""

        corpus = [["government", "denied", "holiday"], ["holiday", "slowing", "hollingworth"]]
        dictionary = Dictionary(corpus)

        # checking symmetry and the existence of ones on the diagonal
        similarity_matrix = self.vectors.similarity_matrix(dictionary).todense()
        self.assertTrue((similarity_matrix.T == similarity_matrix).all())
        self.assertTrue((np.diag(similarity_matrix) == 1.0).all())

        # checking that thresholding works as expected
        similarity_matrix = self.vectors.similarity_matrix(dictionary, threshold=0.45).todense()
        self.assertEqual(18, np.sum(similarity_matrix == 0))

        # checking that exponent works as expected
        similarity_matrix = self.vectors.similarity_matrix(dictionary, exponent=1.0).todense()
        self.assertAlmostEqual(9.5788956, np.sum(similarity_matrix))

        # checking that nonzero_limit works as expected
        similarity_matrix = self.vectors.similarity_matrix(dictionary, nonzero_limit=4).todense()
        self.assertEqual(4, np.sum(similarity_matrix == 0))

        similarity_matrix = self.vectors.similarity_matrix(dictionary, nonzero_limit=3).todense()
        self.assertEqual(20, np.sum(similarity_matrix == 0))

def test_most_similar(self):
"""Test most_similar returns expected results."""
expected = [