
New KeyedVectors.vectors_for_all method for vectorizing all words in a dictionary #3157
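This PR adds a new KeyedVectors.vectors_for_all method that takes any iterable of keys (words) and returns a new KeyedVectors object holding a vector for every key it can resolve, plus a supporting Dictionary.most_common helper. As a minimal sketch of the new API (using gensim's bundled common_texts toy corpus; the word list here is illustrative):

    from gensim.models import Word2Vec
    from gensim.test.utils import common_texts

    model = Word2Vec(common_texts, vector_size=20, min_count=1)  # train word-vectors
    words = ['human', 'computer', 'graph']  # any iterable of keys
    word_vectors = model.wv.vectors_for_all(words)  # a new KeyedVectors for just these words
    print(len(word_vectors))  # how many keys received a vector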

Merged 39 commits into develop from feature/vectors-for-all on Jun 29, 2021

Commits (39)
ab6fb90  Add KeyedVectors.vectors_for_all (Witiko, May 25, 2021)
98ed69d  Add examples for KeyedVectors.vectors_for_all (Witiko, May 25, 2021)
be1746b  Support Dictionary in KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
d81df64  Don't sort keys in KeyedVectors.vectors_for_all, just deduplicate (Witiko, May 28, 2021)
ef8bea6  Use docstrings in imperative mode (PEP8) (Witiko, May 28, 2021)
d602018  Guard against KeyError in KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
13a7ecd  Unit-test dictionary parameter of KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
6a8c688  Order dictionary by decreasing cfs in KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
9ebe808  Add allow_inference parameter to KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
716dc32  Add copy_vecattrs parameter to KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
77e1889  Move copy_vecattrs tests for KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
330d5f7  Fix translation of term ids to terms in KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
8fdda93  Fix a typo in KeyedVectors.vectors_for_all unit test (Witiko, May 28, 2021)
ba636a2  Do not make assumptions about fake counts in _add_word_to_kv (Witiko, May 28, 2021)
1a9ea9b  Document that KeyedVectors.vectors_for_all allows arbitrary keys (Witiko, May 28, 2021)
e5a9a31  Add notes about the behavior of KeyedVectors.vectors_for_all (Witiko, May 28, 2021)
5eebef0  Properly reference Dictionary in KeyedVectors.vectors_for_all docstring (Witiko, May 28, 2021)
26baf6d  Make deduplication in KeyedVectors.vectors_for_all a oneliner (Witiko, May 31, 2021)
98c070e  Remove an unnecessary temporary variable in KeyedVectors.vectors_for_all (Witiko, May 31, 2021)
8e4d0cf  Make deduplication in KeyedVectors.vectors_for_all a oneliner (cont.) (Witiko, May 31, 2021)
a4590c1  Add Dictionary.most_common (Witiko, May 31, 2021)
b14298b  Remove test_vectors_for_all_dictionary unit test (Witiko, May 31, 2021)
1cf9452  Remove a trailing bracket in an example (Witiko, May 31, 2021)
9c6f296  Fix unit tests for Dictionary.most_common (Witiko, May 31, 2021)
e78bfa3  Update an example for SparseTermSimilarityMatrix (Witiko, May 31, 2021)
32c14c5  Remove Gensim downloader from KeyedVectors.vectors_for_all example (Witiko, Jun 22, 2021)
9acbcba  Remove include_counts parameter from Dictionary.most_common (Witiko, Jun 22, 2021)
712ee61  Shorten the KeyedVectors.vectors_for_all example (Witiko, Jun 22, 2021)
b8625a5  Remove include_counts parameter from Dictionary.most_common (cont.) (Witiko, Jun 22, 2021)
4aacad2  Use pytest assertion syntax in unit tests (Witiko, Jun 22, 2021)
a86522c  Remove an unnecessary comment in KeyedVectors.vectors_for_all (Witiko, Jun 22, 2021)
7ea8337  Remove an unnecessary comment in KeyedVectors.vectors_for_all (Witiko, Jun 22, 2021)
f08c582  Remove an unnecessary variable in KeyedVectors.vectors_for_all (Witiko, Jun 22, 2021)
ebc276d  Make the creation of new vocab in KeyedVectors.vectors_for_all explicit (Witiko, Jun 22, 2021)
3bf7f33  Make AnnoyIndexer use the correct word-vectors in example (Witiko, Jun 22, 2021)
68b5fc1  Apply suggestions from code review (mpenkov, Jun 29, 2021)
52e5ee8  Apply suggestions from code review (mpenkov, Jun 29, 2021)
4dc3756  Update CHANGELOG.md (mpenkov, Jun 29, 2021)
d319144  Merge branch 'develop' into feature/vectors-for-all (mpenkov, Jun 29, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ Changes
* [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness)
* [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko)
* [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar)
* [#3157](https://github.com/RaRe-Technologies/gensim/pull/3157): New KeyedVectors.vectors_for_all method for vectorizing all words in a dictionary, by [@Witiko](https://github.com/Witiko)
* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0)
* [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro)
* [#3169](https://github.com/RaRe-Technologies/gensim/pull/3169): Implement `shrink_windows` argument for Word2Vec., by [@M-Demay](https://github.com/M-Demay)
25 changes: 25 additions & 0 deletions gensim/corpora/dictionary.py
@@ -10,6 +10,7 @@
from collections.abc import Mapping
import logging
import itertools
from typing import Optional, List, Tuple

from gensim import utils

@@ -689,6 +690,30 @@ def load_from_text(fname):
result.dfs[wordid] = int(docfreq)
return result

def most_common(self, n: Optional[int] = None) -> List[Tuple[str, int]]:
"""Return a list of the n most common words and their counts from the most common to the least.

Words with equal counts are ordered in increasing order of their ids.

Parameters
----------
n : int or None, optional
The number of most common words to be returned. If `None`, all words in the dictionary
will be returned. Default is `None`.

Returns
-------
most_common : list of (str, int)
The n most common words and their counts from the most common to the least.

"""
most_common = [
(self[word], count)
for word, count
in sorted(self.cfs.items(), key=lambda x: (-x[1], x[0]))[:n]
]
return most_common

@staticmethod
def from_corpus(corpus, id2word=None):
"""Create :class:`~gensim.corpora.dictionary.Dictionary` from an existing corpus.
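For illustration, a minimal sketch of the new Dictionary.most_common in action (the toy corpus mirrors the unit tests added further below):

    from gensim.corpora import Dictionary

    texts = [['human', 'human', 'human', 'computer', 'computer', 'interface', 'interface']]
    dictionary = Dictionary(texts)
    print(dictionary.most_common(n=2))  # [('human', 3), ('computer', 2)]
    print(dictionary.most_common())     # all words, most common first; ties broken by increasing id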
65 changes: 65 additions & 0 deletions gensim/models/keyedvectors.py
@@ -171,6 +171,7 @@
import itertools
import warnings
from numbers import Integral
from typing import Iterable

from numpy import (
dot, float32 as REAL, double, array, zeros, vstack,
@@ -1689,6 +1690,70 @@ def intersect_word2vec_format(self, fname, lockf=0.0, binary=False, encoding='ut
msg=f"merged {overlap_count} vectors into {self.vectors.shape} matrix from {fname}",
)

def vectors_for_all(self, keys: Iterable, allow_inference: bool = True,
copy_vecattrs: bool = False) -> 'KeyedVectors':
"""Produce vectors for all given keys as a new :class:`KeyedVectors` object.

Notes
-----
The keys will always be deduplicated. For optimal performance, you should not pass entire
corpora to the method. Instead, you should construct a dictionary of unique words in your
corpus:

>>> from collections import Counter
>>> import itertools
>>>
>>> from gensim.models import FastText
>>> from gensim.test.utils import datapath, common_texts
>>>
>>> model_corpus_file = datapath('lee_background.cor') # train word vectors on some corpus
>>> model = FastText(corpus_file=model_corpus_file, vector_size=20, min_count=1)
>>> corpus = common_texts # infer word vectors for words from another corpus
>>> word_counts = Counter(itertools.chain.from_iterable(corpus)) # count words in your corpus
>>> words_by_freq = (k for k, v in word_counts.most_common())
>>> word_vectors = model.wv.vectors_for_all(words_by_freq) # create word-vectors for words in your corpus

Parameters
----------
keys : iterable
The keys that will be vectorized.
allow_inference : bool, optional
In subclasses such as :class:`~gensim.models.fasttext.FastTextKeyedVectors`,
vectors for out-of-vocabulary keys (words) may be inferred. Default is True.
copy_vecattrs : bool, optional
Additional attributes set via the :meth:`KeyedVectors.set_vecattr` method
will be preserved in the produced :class:`KeyedVectors` object. Default is False.
To ensure that *all* the produced vectors will have vector attributes assigned,
you should set `allow_inference=False`.

Returns
-------
keyedvectors : :class:`~gensim.models.keyedvectors.KeyedVectors`
Vectors for all the given keys.

"""
# Pick only the keys that actually exist & deduplicate them.
# Keep the original key order to improve cache locality and performance.
vocab, seen = [], set()
for key in keys:
if key not in seen:
seen.add(key)
if key in (self if allow_inference else self.key_to_index):
vocab.append(key)

kv = KeyedVectors(self.vector_size, len(vocab), dtype=self.vectors.dtype)

for key in vocab: # produce and index vectors for all the given keys
weights = self[key]
_add_word_to_kv(kv, None, key, weights, len(vocab))
if copy_vecattrs:
for attr in self.expandos:
try:
kv.set_vecattr(key, attr, self.get_vecattr(key, attr))
except KeyError:
pass
return kv

def _upconvert_old_d2vkv(self):
"""Convert a deserialized older Doc2VecKeyedVectors instance to latest generic KeyedVectors"""
self.vocab = self.doctags
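A short sketch of the two new keyword parameters, assuming a FastText model so that out-of-vocabulary inference is possible ('sasquatch' is just an illustrative OOV word):

    from gensim.models import FastText
    from gensim.test.utils import common_texts

    model = FastText(common_texts, vector_size=20, min_count=1)
    words = ['human', 'sasquatch']  # 'sasquatch' does not occur in common_texts

    inferred = model.wv.vectors_for_all(words)  # allow_inference=True is the default
    assert len(inferred) == 2  # the OOV word gets a vector inferred from its subwords

    in_vocab = model.wv.vectors_for_all(words, allow_inference=False)
    assert len(in_vocab) == 1  # only 'human' is kept

    with_attrs = model.wv.vectors_for_all(words, allow_inference=False, copy_vecattrs=True)
    assert with_attrs.get_vecattr('human', 'count') == model.wv.get_vecattr('human', 'count')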
50 changes: 38 additions & 12 deletions gensim/similarities/termsim.py
@@ -102,6 +102,28 @@ class WordEmbeddingSimilarityIndex(TermSimilarityIndex):
Computes cosine similarities between word embeddings and retrieves most
similar terms for a given term.

Notes
-----
By fitting the word embeddings to a vocabulary that you will be using, you
can eliminate all out-of-vocabulary (OOV) words that you would otherwise
receive from the `most_similar` method. In subword models such as fastText,
this procedure will also infer word-vectors for words from your vocabulary
that previously had no word-vector.

>>> from gensim.test.utils import common_texts, datapath
>>> from gensim.corpora import Dictionary
>>> from gensim.models import FastText
>>> from gensim.models.word2vec import LineSentence
>>> from gensim.similarities import WordEmbeddingSimilarityIndex
>>>
>>> model = FastText(common_texts, vector_size=20, min_count=1) # train word-vectors on a corpus
>>> different_corpus = LineSentence(datapath('lee_background.cor'))
>>> dictionary = Dictionary(different_corpus) # construct a vocabulary on a different corpus
>>> words = [word for word, count in dictionary.most_common()]
>>> word_vectors = model.wv.vectors_for_all(words) # remove OOV word-vectors and infer word-vectors for new words
>>> assert len(dictionary) == len(word_vectors) # all words from our vocabulary received their word-vectors
>>> termsim_index = WordEmbeddingSimilarityIndex(word_vectors)

Parameters
----------
keyedvectors : :class:`~gensim.models.keyedvectors.KeyedVectors`
@@ -404,25 +426,29 @@ class SparseTermSimilarityMatrix(SaveLoad):

Examples
--------
- >>> from gensim.test.utils import common_texts
+ >>> from gensim.test.utils import common_texts as corpus, datapath
  >>> from gensim.corpora import Dictionary
- >>> from gensim.models import Word2Vec
+ >>> from gensim.models import TfidfModel, Word2Vec
  >>> from gensim.similarities import SoftCosineSimilarity, SparseTermSimilarityMatrix, WordEmbeddingSimilarityIndex
  >>> from gensim.similarities.index import AnnoyIndexer
  >>> from scikits.sparse.cholmod import cholesky
  >>>
- >>> model = Word2Vec(common_texts, vector_size=20, min_count=1)  # train word-vectors
- >>> annoy = AnnoyIndexer(model, num_trees=2)  # use annoy for faster word similarity lookups
- >>> termsim_index = WordEmbeddingSimilarityIndex(model.wv, kwargs={'indexer': annoy})
- >>> dictionary = Dictionary(common_texts)
- >>> bow_corpus = [dictionary.doc2bow(document) for document in common_texts]
- >>> similarity_matrix = SparseTermSimilarityMatrix(termsim_index, dictionary, symmetric=True, dominant=True)
- >>> docsim_index = SoftCosineSimilarity(bow_corpus, similarity_matrix, num_best=10)
- >>>
- >>> query = 'graph trees computer'.split()  # make a query
- >>> sims = docsim_index[dictionary.doc2bow(query)]  # calculate similarity of query to each doc from bow_corpus
+ >>> model_corpus_file = datapath('lee_background.cor')
+ >>> model = Word2Vec(corpus_file=model_corpus_file, vector_size=20, min_count=1)  # train word-vectors
+ >>>
+ >>> dictionary = Dictionary(corpus)
+ >>> tfidf = TfidfModel(dictionary=dictionary)
+ >>> words = [word for word, count in dictionary.most_common()]
+ >>> word_vectors = model.wv.vectors_for_all(words, allow_inference=False)  # produce vectors for words in corpus
+ >>>
+ >>> indexer = AnnoyIndexer(word_vectors, num_trees=2)  # use Annoy for faster word similarity lookups
+ >>> termsim_index = WordEmbeddingSimilarityIndex(word_vectors, kwargs={'indexer': indexer})
+ >>> similarity_matrix = SparseTermSimilarityMatrix(termsim_index, dictionary, tfidf)  # compute word similarities
+ >>>
+ >>> word_embeddings = cholesky(similarity_matrix.matrix).L()  # obtain word embeddings from similarity matrix
+ >>> tfidf_corpus = tfidf[[dictionary.doc2bow(document) for document in corpus]]
+ >>> docsim_index = SoftCosineSimilarity(tfidf_corpus, similarity_matrix, num_best=10)  # index tfidf_corpus
+ >>>
+ >>> query = 'graph trees computer'.split()  # make a query
+ >>> sims = docsim_index[dictionary.doc2bow(query)]  # find the ten closest documents from tfidf_corpus

Check out `the Gallery <https://radimrehurek.com/gensim/auto_examples/tutorials/run_scm.html>`_
for more examples.
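As a follow-up to the SparseTermSimilarityMatrix example above: with num_best=10, querying the SoftCosineSimilarity index yields up to ten (document index, similarity) pairs, sorted best match first. A minimal sketch of consuming them (printed values are illustrative):

    for document_index, similarity in sims:  # continues the example above
        print(document_index, round(similarity, 4))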
12 changes: 12 additions & 0 deletions gensim/test/test_corpora_dictionary.py
@@ -359,6 +359,18 @@ def test_patch_with_special_tokens(self):
self.assertNotIn((1, 1), d.doc2bow(corpus_with_special_tokens[0]))
self.assertIn((1, 1), d.doc2bow(corpus_with_special_tokens[1]))

def test_most_common_with_n(self):
texts = [['human', 'human', 'human', 'computer', 'computer', 'interface', 'interface']]
d = Dictionary(texts)
expected = [('human', 3), ('computer', 2)]
assert d.most_common(n=2) == expected

def test_most_common_without_n(self):
texts = [['human', 'human', 'human', 'computer', 'computer', 'interface', 'interface']]
d = Dictionary(texts)
expected = [('human', 3), ('computer', 2), ('interface', 2)]
assert d.most_common(n=None) == expected


# endclass TestDictionary

48 changes: 48 additions & 0 deletions gensim/test/test_fasttext.py
@@ -714,6 +714,54 @@ def obsolete_testLoadOldModel(self):
self.assertEqual(model.wv.vectors_vocab.shape, (12, 100))
self.assertEqual(model.wv.vectors_ngrams.shape, (2000000, 100))

def test_vectors_for_all_with_inference(self):
"""Test vectors_for_all can infer new vectors."""
words = [
'responding',
'approached',
'chairman',
'an out-of-vocabulary word',
'another out-of-vocabulary word',
]
vectors_for_all = self.test_model.wv.vectors_for_all(words)

expected = 5
predicted = len(vectors_for_all)
assert expected == predicted

expected = self.test_model.wv['responding']
predicted = vectors_for_all['responding']
assert np.allclose(expected, predicted)

smaller_distance = np.linalg.norm(
vectors_for_all['an out-of-vocabulary word']
- vectors_for_all['another out-of-vocabulary word']
)
greater_distance = np.linalg.norm(
vectors_for_all['an out-of-vocabulary word']
- vectors_for_all['responding']
)
assert greater_distance > smaller_distance

def test_vectors_for_all_without_inference(self):
"""Test vectors_for_all does not infer new vectors when prohibited."""
words = [
'responding',
'approached',
'chairman',
'an out-of-vocabulary word',
'another out-of-vocabulary word',
]
vectors_for_all = self.test_model.wv.vectors_for_all(words, allow_inference=False)

expected = 3
predicted = len(vectors_for_all)
assert expected == predicted

expected = self.test_model.wv['responding']
predicted = vectors_for_all['responding']
assert np.allclose(expected, predicted)


@pytest.mark.parametrize('shrink_windows', [True, False])
def test_cbow_hs_training(shrink_windows):
37 changes: 37 additions & 0 deletions gensim/test/test_keyedvectors.py
@@ -39,6 +39,43 @@ def test_most_similar(self):
predicted = [result[0] for result in self.vectors.most_similar('war', topn=5)]
self.assertEqual(expected, predicted)

def test_vectors_for_all_list(self):
"""Test vectors_for_all returns expected results with a list of keys."""
words = [
'conflict',
'administration',
'terrorism',
'an out-of-vocabulary word',
'another out-of-vocabulary word',
]
vectors_for_all = self.vectors.vectors_for_all(words)

expected = 3
predicted = len(vectors_for_all)
assert expected == predicted

expected = self.vectors['conflict']
predicted = vectors_for_all['conflict']
assert np.allclose(expected, predicted)

def test_vectors_for_all_with_copy_vecattrs(self):
"""Test vectors_for_all returns can copy vector attributes."""
words = ['conflict']
vectors_for_all = self.vectors.vectors_for_all(words, copy_vecattrs=True)

expected = self.vectors.get_vecattr('conflict', 'count')
predicted = vectors_for_all.get_vecattr('conflict', 'count')
assert expected == predicted

def test_vectors_for_all_without_copy_vecattrs(self):
"""Test vectors_for_all returns can copy vector attributes."""
words = ['conflict']
vectors_for_all = self.vectors.vectors_for_all(words, copy_vecattrs=False)

not_expected = self.vectors.get_vecattr('conflict', 'count')
predicted = vectors_for_all.get_vecattr('conflict', 'count')  # a placeholder count, since real counts were not copied
assert not_expected != predicted

def test_most_similar_topn(self):
"""Test most_similar returns correct results when `topn` is specified."""
self.assertEqual(len(self.vectors.most_similar('war', topn=5)), 5)