Skip to content

Commit

Permalink
Fix deprecations in SoftCosineSimilarity (#2940)
Browse files Browse the repository at this point in the history
* Remove deprecated Soft Cosine Measure parameters, functions, and tests.

Here is a detailed list of the deprecations:
- Parameter `positive_definite` of `SparseTermSimilarityMatrix` has been
  renamed to `dominant`. Test `test_positive_definite` has been removed.
- Parameter `similarity_matrix` of `SoftCosineSimilarity` no longer
  accepts unencapsulated sparse matrices.
- Parameter `normalized` of `SparseTermSimilarityMatrix.inner_product`
  no longer accepts booleans.
- Function `matutils.softcossim` has been superseded by method
  `SparseTermSimilarityMatrix.inner_product`. Tests in
  `TestSoftCosineSimilarity` have been removed.

* Remove unused imports

* Fix additional warnings from the CI test suite

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
Witiko and mpenkov committed Sep 16, 2020
1 parent bb947b3 commit 09b7e94
Show file tree
Hide file tree
Showing 12 changed files with 13 additions and 169 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This release contains a major refactoring.
* No more wheels for x32 platforms (if you need x32 binaries, please build them yourself).
(__[menshikh-iv](https://github.com/menshikh-iv)__, [#6](https://github.com/RaRe-Technologies/gensim-wheels/pull/6))
* Speed up random number generation in word2vec model (PR [#2864](https://github.com/RaRe-Technologies/gensim/pull/2864), __[@zygm0nt](https://github.com/zygm0nt)__)
* Fix deprecations in SoftCosineSimilarity (PR [#2940](https://github.com/RaRe-Technologies/gensim/pull/2940), __[@Witiko](https://github.com/Witiko)__)
* Remove Keras dependency (PR [#2937](https://github.com/RaRe-Technologies/gensim/pull/2937), __[@piskvorky](https://github.com/piskvorky)__)

### :books: Tutorial and doc improvements
Expand Down
83 changes: 3 additions & 80 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
from __future__ import with_statement


from itertools import chain
import logging
import math

from gensim import utils
from gensim.utils import deprecated

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -193,9 +191,9 @@ def pad(mat, padrow, padcol):
if padcol < 0:
padcol = 0
rows, cols = mat.shape
return np.bmat([
[mat, np.matrix(np.zeros((rows, padcol)))],
[np.matrix(np.zeros((padrow, cols + padcol)))],
return np.block([
[mat, np.zeros((rows, padcol))],
[np.zeros((padrow, cols + padcol))],
])


Expand Down Expand Up @@ -819,81 +817,6 @@ def cossim(vec1, vec2):
return result


@deprecated(
"Function will be removed in 4.0.0, use "
"gensim.similarities.termsim.SparseTermSimilarityMatrix.inner_product instead")
def softcossim(vec1, vec2, similarity_matrix):
"""Get Soft Cosine Measure between two vectors given a term similarity matrix.
Return Soft Cosine Measure between two sparse vectors given a sparse term similarity matrix
in the :class:`scipy.sparse.csc_matrix` format. The similarity is a number between `<-1.0, 1.0>`,
higher is more similar.
Notes
-----
Soft Cosine Measure was perhaps first defined by `Grigori Sidorov et al.,
"Soft Similarity and Soft Cosine Measure: Similarity of Features in Vector Space Model"
<http://www.cys.cic.ipn.mx/ojs/index.php/CyS/article/view/2043/1921>`_.
Parameters
----------
vec1 : list of (int, float)
A query vector in the BoW format.
vec2 : list of (int, float)
A document vector in the BoW format.
similarity_matrix : {:class:`scipy.sparse.csc_matrix`, :class:`scipy.sparse.csr_matrix`}
A term similarity matrix. If the matrix is :class:`scipy.sparse.csr_matrix`, it is going
to be transposed. If you rely on the fact that there is at most a constant number of
non-zero elements in a single column, it is your responsibility to ensure that the matrix
is symmetric.
Returns
-------
`similarity_matrix.dtype`
The Soft Cosine Measure between `vec1` and `vec2`.
Raises
------
ValueError
When the term similarity matrix is in an unknown format.
See Also
--------
:meth:`gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.similarity_matrix`
A term similarity matrix produced from term embeddings.
:class:`gensim.similarities.docsim.SoftCosineSimilarity`
A class for performing corpus-based similarity queries with Soft Cosine Measure.
"""
if not isinstance(similarity_matrix, scipy.sparse.csc_matrix):
if isinstance(similarity_matrix, scipy.sparse.csr_matrix):
similarity_matrix = similarity_matrix.T
else:
raise ValueError('unknown similarity matrix format')

if not vec1 or not vec2:
return 0.0

vec1 = dict(vec1)
vec2 = dict(vec2)
word_indices = sorted(set(chain(vec1, vec2)))
dtype = similarity_matrix.dtype
vec1 = np.fromiter((vec1[i] if i in vec1 else 0 for i in word_indices), dtype=dtype, count=len(word_indices))
vec2 = np.fromiter((vec2[i] if i in vec2 else 0 for i in word_indices), dtype=dtype, count=len(word_indices))
dense_matrix = similarity_matrix[[[i] for i in word_indices], word_indices].todense()
vec1len = vec1.T.dot(dense_matrix).dot(vec1)[0, 0]
vec2len = vec2.T.dot(dense_matrix).dot(vec2)[0, 0]

assert \
vec1len > 0.0 and vec2len > 0.0, \
u"sparse documents must not contain any explicit zero entries and the similarity matrix S " \
u"must satisfy x^T * S * x > 0 for any nonzero bag-of-words vector x."

result = vec1.T.dot(dense_matrix).dot(vec2)[0, 0]
result /= math.sqrt(vec1len) * math.sqrt(vec2len) # rescale by vector lengths
return np.clip(result, -1.0, 1.0)


def isbow(vec):
"""Checks if a vector is in the sparse Gensim bag-of-words format.
Expand Down
2 changes: 1 addition & 1 deletion gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def save_word2vec_format(self, fname, fvocab=None, binary=False, total_vec=None,
row = self[key]
if binary:
row = row.astype(REAL)
fout.write(utils.to_utf8(prefix + str(key)) + b" " + row.tostring())
fout.write(utils.to_utf8(prefix + str(key)) + b" " + row.tobytes())
else:
fout.write(utils.to_utf8("%s%s %s\n" % (prefix, str(key), ' '.join(repr(val) for val in row))))

Expand Down
3 changes: 1 addition & 2 deletions gensim/models/wrappers/wordrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ def load_wordrank_model(cls, model_file, vocab_file=None, context_file=None, sor
If 1 - use ensemble of word and context vectors.
"""
glove2word2vec(model_file, model_file + '.w2vformat')
model = cls.load_word2vec_format('%s.w2vformat' % model_file)
model = cls.load_word2vec_format(model_file, binary=False, no_header=True)
if ensemble and context_file:
model.ensemble_embedding(model_file, context_file)
if sorted_vocab and vocab_file:
Expand Down
9 changes: 1 addition & 8 deletions gensim/similarities/docsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
import scipy.sparse

from gensim import interfaces, utils, matutils
from .termsim import SparseTermSimilarityMatrix
from six.moves import map, range, zip


Expand Down Expand Up @@ -931,13 +930,7 @@ def __init__(self, corpus, similarity_matrix, num_best=None, chunksize=256):
A term similarity index that computes cosine similarities between word embeddings.
"""
if scipy.sparse.issparse(similarity_matrix):
logger.warn(
"Support for passing an unencapsulated sparse matrix will be removed in 4.0.0, pass "
"a SparseTermSimilarityMatrix instance instead")
self.similarity_matrix = SparseTermSimilarityMatrix(similarity_matrix)
else:
self.similarity_matrix = similarity_matrix
self.similarity_matrix = similarity_matrix

self.corpus = corpus
self.num_best = num_best
Expand Down
20 changes: 1 addition & 19 deletions gensim/similarities/termsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from itertools import chain
import logging
from math import sqrt
import warnings

import numpy as np
from six.moves import range
Expand Down Expand Up @@ -457,8 +456,6 @@ class SparseTermSimilarityMatrix(SaveLoad):
sparse term similarity matrix. If None, then no limit will be imposed.
dtype : numpy.dtype, optional
The data type of the sparse term similarity matrix.
positive_definite: bool or None, optional
A deprecated alias for dominant.
Attributes
----------
Expand All @@ -472,14 +469,7 @@ class SparseTermSimilarityMatrix(SaveLoad):
"""
def __init__(self, source, dictionary=None, tfidf=None, symmetric=True, dominant=False,
nonzero_limit=100, dtype=np.float32, positive_definite=None):

if positive_definite is not None:
warnings.warn(
'Parameter positive_definite will be removed in 4.0.0, use dominant instead',
category=DeprecationWarning,
)
dominant = positive_definite
nonzero_limit=100, dtype=np.float32):

if not sparse.issparse(source):
index = source
Expand Down Expand Up @@ -529,14 +519,6 @@ def inner_product(self, X, Y, normalized=(False, False)):
if not X or not Y:
return self.matrix.dtype.type(0.0)

if normalized in (True, False):
warnings.warn(
'Boolean parameter normalized will be removed in 4.0.0, use '
'normalized=(%s, %s) instead of normalized=%s' % tuple([normalized] * 3),
category=DeprecationWarning,
)
normalized = (normalized, normalized)

normalized_X, normalized_Y = normalized
valid_normalized_values = (True, False, 'maintain')

Expand Down
3 changes: 2 additions & 1 deletion gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,8 @@ def test_in_vocab(self):

def test_out_of_vocab(self):
model = train_gensim(bucket=0)
self.assertRaises(KeyError, model.wv.word_vec, 'streamtrain')
with self.assertRaises(KeyError):
model.wv.get_vector('streamtrain')

def test_cbow_neg(self):
"""See `gensim.test.test_word2vec.TestWord2VecModel.test_cbow_neg`."""
Expand Down
2 changes: 1 addition & 1 deletion gensim/test/test_lsimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def testTransformFloat32(self):
def testCorpusTransform(self):
"""Test lsi[corpus] transformation."""
model = self.model
got = np.vstack(matutils.sparse2full(doc, 2) for doc in model[self.corpus])
got = np.vstack([matutils.sparse2full(doc, 2) for doc in model[self.corpus]])
expected = np.array([
[0.65946639, 0.14211544],
[2.02454305, -0.42088759],
Expand Down
2 changes: 1 addition & 1 deletion gensim/test/test_phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def testBigramConstructionFromArray(self):
bigram1_seen = False
bigram2_seen = False

for s in self.bigram[np.array(self.sentences)]:
for s in self.bigram[np.array(self.sentences, dtype=object)]:
if not bigram1_seen and self.bigram1 in s:
bigram1_seen = True
if not bigram2_seen and self.bigram2 in s:
Expand Down
23 changes: 0 additions & 23 deletions gensim/test/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,29 +974,6 @@ def test_dominant(self):
[0.0, 0.0, 0.0, 0.0, 1.0]])
self.assertTrue(numpy.all(expected_matrix == matrix))

def test_positive_definite(self):
"""Test the positive_definite parameter of the matrix constructor."""
negative_index = UniformTermSimilarityIndex(self.dictionary, term_similarity=-0.5)
matrix = SparseTermSimilarityMatrix(
negative_index, self.dictionary, nonzero_limit=2).matrix.todense()
expected_matrix = numpy.array([
[1.0, -.5, -.5, 0.0, 0.0],
[-.5, 1.0, 0.0, -.5, 0.0],
[-.5, 0.0, 1.0, 0.0, 0.0],
[0.0, -.5, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0]])
self.assertTrue(numpy.all(expected_matrix == matrix))

matrix = SparseTermSimilarityMatrix(
negative_index, self.dictionary, nonzero_limit=2, positive_definite=True).matrix.todense()
expected_matrix = numpy.array([
[1.0, -.5, 0.0, 0.0, 0.0],
[-.5, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0]])
self.assertTrue(numpy.all(expected_matrix == matrix))

def test_tfidf(self):
"""Test the tfidf parameter of the matrix constructor."""
matrix = SparseTermSimilarityMatrix(
Expand Down
32 changes: 1 addition & 31 deletions gensim/test/test_similarity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import unittest

from gensim import matutils
from scipy.sparse import csr_matrix, csc_matrix
from scipy.sparse import csr_matrix
import numpy as np
import math
from gensim.corpora.mmcorpus import MmCorpus
Expand Down Expand Up @@ -240,36 +240,6 @@ def test_distributions(self):
self.assertAlmostEqual(expected, result)


class TestSoftCosineSimilarity(unittest.TestCase):
def test_inputs(self):
# checking empty inputs
vec_1 = []
vec_2 = []
similarity_matrix = csc_matrix((0, 0))
result = matutils.softcossim(vec_1, vec_2, similarity_matrix)
expected = 0.0
self.assertEqual(expected, result)

# checking CSR term similarity matrix format
similarity_matrix = csr_matrix((0, 0))
result = matutils.softcossim(vec_1, vec_2, similarity_matrix)
expected = 0.0
self.assertEqual(expected, result)

# checking unknown term similarity matrix format
with self.assertRaises(ValueError):
matutils.softcossim(vec_1, vec_2, np.matrix([]))

def test_distributions(self):
# checking bag of words as inputs
vec_1 = [(0, 1.0), (2, 1.0)] # hello world
vec_2 = [(1, 1.0), (2, 1.0)] # hi world
similarity_matrix = csc_matrix([[1, 0.5, 0], [0.5, 1, 0], [0, 0, 1]])
result = matutils.softcossim(vec_1, vec_2, similarity_matrix)
expected = 0.75
self.assertAlmostEqual(expected, result)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions gensim/test/test_wordrank_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,13 @@ def testLoadWordrankFormat(self):
vocab_size, dim = 76, 50
self.assertEqual(model.vectors.shape, (vocab_size, dim))
self.assertEqual(len(model), vocab_size)
os.remove(self.wr_file + '.w2vformat')

def testEnsemble(self):
"""Test ensemble of two embeddings"""
if not self.wr_path:
return
new_emb = self.test_model.ensemble_embedding(self.wr_file, self.wr_file)
self.assertEqual(new_emb.shape, (76, 50))
os.remove(self.wr_file + '.w2vformat')

def testPersistence(self):
"""Test storing/loading the entire model"""
Expand Down

0 comments on commit 09b7e94

Please sign in to comment.