From 4d0661faf015b4c9f06aa8b38e1db9778e3abe00 Mon Sep 17 00:00:00 2001 From: Pushpankar Date: Mon, 5 Feb 2018 20:36:16 +0530 Subject: [PATCH] Test type conversion --- gensim/test/test_datatype.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/gensim/test/test_datatype.py b/gensim/test/test_datatype.py index f6e78ea036..42b05c4fb5 100644 --- a/gensim/test/test_datatype.py +++ b/gensim/test/test_datatype.py @@ -25,6 +25,8 @@ def load_model(self, datatype): def test_high_precision(self): kv = self.load_model(np.float64) + import pdb + pdb.set_trace() self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379) self.assertEqual(kv['horse.n.01'][0].dtype, np.float64) @@ -38,6 +40,16 @@ def test_low_precision(self): self.assertAlmostEqual(kv['horse.n.01'][0], -0.00085449) self.assertEqual(kv['horse.n.01'][0].dtype, np.float16) + def test_type_conversion(self): + path = datapath('test.kv.txt') + binary_path = datapath('test.kv.bin') + model1 = KeyedVectors.load_word2vec_format(path, datatype=np.float16) + model1.save_word2vec_format(binary_path, binary=True) + model2 = KeyedVectors.load_word2vec_format(binary_path, datatype=np.float64, binary=True) + import pdb + pdb.set_trace() + self.assertAlmostEqual(model1["horse.n.01"][0], np.float16(model2["horse.n.01"][0])) + if __name__ == '__main__': logging.root.setLevel(logging.WARNING)