Skip to content

Commit

Permalink
Order dictionary by decreasing cfs in KeyedVectors.vectors_for_all
Browse files Browse the repository at this point in the history
  • Loading branch information
Witiko committed May 28, 2021
1 parent 13a7ecd commit 6a8c688
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,7 +1724,8 @@ def vectors_for_all(self, keys: Union[Iterable, Dictionary]) -> 'KeyedVectors':
"""
if isinstance(keys, Dictionary):
vocabulary = [key for key in keys.token2id if key in self]
term_ids = sorted(keys.cfs.items(), key=lambda x: (-x[1], x[0])) # sort by decreasing frequency
vocabulary = [term_id for term_id, freq in term_ids if term_id in self]
else:
vocabulary = (key for key in keys if key in self)
vocabulary = list(OrderedDict.fromkeys(vocabulary)) # deduplicate keys
Expand Down
7 changes: 6 additions & 1 deletion gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_vectors_for_all_dictionary(self):
words = Dictionary([[
'conflict',
'administration',
'terrorism',
'terrorism', 'terrorism',
'an out-of-vocabulary word',
'another out-of-vocabulary word',
]])
Expand All @@ -78,6 +78,11 @@ def test_vectors_for_all_dictionary(self):
predicted = vectors_for_all['conflict']
self.assertTrue(np.allclose(expected, predicted))

# terrorism has the highest frequency, so it should be indexed first
expected = 'terrorism'
predicted = vectors_for_all.index_to_key[0]
self.assertEqual(expected, predicted)

def test_most_similar_topn(self):
"""Test most_similar returns correct results when `topn` is specified."""
self.assertEqual(len(self.vectors.most_similar('war', topn=5)), 5)
Expand Down

0 comments on commit 6a8c688

Please sign in to comment.