Skip to content

Commit

Permalink
Add allow_inference parameter to KeyedVectors.vectors_for_all
Browse files Browse the repository at this point in the history
  • Loading branch information
Witiko committed May 28, 2021
1 parent 6a8c688 commit 686945a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
27 changes: 19 additions & 8 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,8 @@ def intersect_word2vec_format(self, fname, lockf=0.0, binary=False, encoding='ut
msg=f"merged {overlap_count} vectors into {self.vectors.shape} matrix from {fname}",
)

def vectors_for_all(self, keys: Union[Iterable, Dictionary]) -> 'KeyedVectors':
def vectors_for_all(self, keys: Union[Iterable, Dictionary],
allow_inference: bool = True) -> 'KeyedVectors':
"""Produce vectors for all given keys.
Notes
Expand All @@ -1716,23 +1717,33 @@ def vectors_for_all(self, keys: Union[Iterable, Dictionary]) -> 'KeyedVectors':
----------
keys : {iterable of str, Dictionary}
The keys that will be vectorized.
allow_inference : bool, optional
In subclasses such as :class:`~gensim.models.fasttext.FastTextKeyedVectors`,
vectors for out-of-vocabulary keys (words) may be inferred.
Returns
-------
keyedvectors : :class:`~gensim.models.keyedvectors.KeyedVectors`
Vectors for all the given keys.
"""
def is_key_defined(key) -> bool:
if allow_inference:
return key in self
else:
return key in self.key_to_index

if isinstance(keys, Dictionary):
term_ids = sorted(keys.cfs.items(), key=lambda x: (-x[1], x[0])) # sort by decreasing frequency
vocabulary = [term_id for term_id, freq in term_ids if term_id in self]
else:
vocabulary = (key for key in keys if key in self)
vocabulary = list(OrderedDict.fromkeys(vocabulary)) # deduplicate keys
keys = sorted(keys.cfs.items(), key=lambda x: (-x[1], x[0])) # sort by decreasing collection frequency
keys = (key for key, _ in keys)
keys = filter(is_key_defined, keys) # remove undefined keys
vocabulary = list(OrderedDict.fromkeys(keys)) # deduplicate keys

vocab_size = len(vocabulary)
datatype = self.vectors.dtype
kv = KeyedVectors(self.vector_size, vocab_size, dtype=datatype)
for key in vocabulary:
kv = KeyedVectors(self.vector_size, vocab_size, dtype=datatype) # preallocate new object

for key in vocabulary: # produce and index vectors for all the given keys
weights = self[key]
_add_word_to_kv(kv, None, key, weights, vocab_size)
return kv
Expand Down
23 changes: 21 additions & 2 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def obsolete_testLoadOldModel(self):
self.assertEqual(model.wv.vectors_vocab.shape, (12, 100))
self.assertEqual(model.wv.vectors_ngrams.shape, (2000000, 100))

def test_vectors_for_all(self):
def test_vectors_for_all_with_inference(self):
"""Test vectors_for_all returns expected results."""
words = [
'responding',
Expand All @@ -859,7 +859,7 @@ def test_vectors_for_all(self):
'an out-of-vocabulary word',
'another out-of-vocabulary word',
]
vectors_for_all = self.test_model.wv.vectors_for_all(words)
vectors_for_all = self.test_model.wv.vectors_for_all(words, allow_inference=True)

expected = 5
predicted = len(vectors_for_all)
Expand All @@ -879,6 +879,25 @@ def test_vectors_for_all(self):
)
self.assertGreater(greater_distance, smaller_distance)

def test_vectors_for_all_without_inference(self):
    """Check vectors_for_all when inference of out-of-vocabulary keys is disabled.

    Five keys are requested with ``allow_inference=False``; the test expects
    only three of them to end up in the result, and the vector stored for
    'responding' to match the one held by the source model.
    """
    source_vectors = self.test_model.wv
    requested_keys = [
        'responding',
        'approached',
        'chairman',
        'an out-of-vocabulary word',
        'another out-of-vocabulary word',
    ]
    restricted = source_vectors.vectors_for_all(requested_keys, allow_inference=False)

    # Only three of the five requested keys are expected to survive.
    self.assertEqual(3, len(restricted))

    # A surviving key must carry the same weights as in the source model.
    reference = source_vectors['responding']
    copied = restricted['responding']
    self.assertTrue(np.allclose(reference, copied))


with open(datapath('toy-data.txt')) as fin:
TOY_SENTENCES = [fin.read().strip().split(' ')]
Expand Down

0 comments on commit 686945a

Please sign in to comment.