From 29f40604b665d32c301d833f075767d96a8e6b2b Mon Sep 17 00:00:00 2001 From: dcarron Date: Thu, 29 Sep 2022 11:57:53 +0200 Subject: [PATCH] Fixed issue when using add_vector with FastTextKeyedVectors Since Gensim 4.0, 'key' in FastTextKeyedVectors always returns True by design. The proper way to check if a key already exists is with 'key' in FastTextKeyedVectors.key_to_index. --- gensim/models/keyedvectors.py | 2 +- gensim/test/test_fasttext.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 6b31496bb5..d720202631 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -592,7 +592,7 @@ def add_vectors(self, keys, weights, extras=None, replace=False): in_vocab_mask = np.zeros(len(keys), dtype=bool) for idx, key in enumerate(keys): - if key in self: + if key in self.key_to_index: in_vocab_mask[idx] = True # add new entities to the vocab diff --git a/gensim/test/test_fasttext.py b/gensim/test/test_fasttext.py index ecc44a30e4..19040d02e9 100644 --- a/gensim/test/test_fasttext.py +++ b/gensim/test/test_fasttext.py @@ -1782,6 +1782,28 @@ def test_identity(self): self.assertTrue(np.all(np.array([6, 7, 8]) == n[2])) +class FastTextKeyedVectorsTest(unittest.TestCase): + def test_add_vector(self): + wv = FastTextKeyedVectors(vector_size=2, min_n=3, max_n=6, bucket=2000000) + wv.add_vector("test_key", np.array([0, 0])) + + self.assertEqual(wv.key_to_index["test_key"], 0) + self.assertEqual(wv.index_to_key[0], "test_key") + self.assertTrue(np.all(wv.vectors[0] == np.array([0, 0]))) + + def test_add_vectors(self): + wv = FastTextKeyedVectors(vector_size=2, min_n=3, max_n=6, bucket=2000000) + wv.add_vectors(["test_key1", "test_key2"], np.array([[0, 0], [1, 1]])) + + self.assertEqual(wv.key_to_index["test_key1"], 0) + self.assertEqual(wv.index_to_key[0], "test_key1") + self.assertTrue(np.all(wv.vectors[0] == np.array([0, 0]))) + + self.assertEqual(wv.key_to_index["test_key2"], 1) + self.assertEqual(wv.index_to_key[1], "test_key2") + self.assertTrue(np.all(wv.vectors[1] == np.array([1, 1]))) + + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) unittest.main()