Skip to content

Commit

Permalink
Ensures FastText returns consistent dtypes
Browse files: browse the repository at this point in the history.
  • Loading branch information
mcobzarenco committed Oct 24, 2017
1 parent e92b45d commit b648db7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gensim/models/wrappers/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def word_vec(self, word, use_norm=False):
if word in self.vocab:
return super(FastTextKeyedVectors, self).word_vec(word, use_norm)
else:
word_vec = np.zeros(self.syn0_ngrams.shape[1])
word_vec = np.zeros(self.syn0_ngrams.shape[1], dtype=np.float32)
ngrams = compute_ngrams(word, self.min_n, self.max_n)
ngrams = [ng for ng in ngrams if ng in self.ngrams]
if use_norm:
Expand Down
6 changes: 6 additions & 0 deletions gensim/test/test_fasttext_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,12 @@ def testHash(self):
ft_hash = fasttext.ft_hash('word')
self.assertEqual(ft_hash, 1788406269)

def testConsistentDtype(self):
    """Check that an imputed (out-of-vocabulary) vector has the same dtype as an in-vocabulary one."""
    model = self.test_model
    # 'night' is looked up as the in-vocabulary word; the long made-up token
    # is looked up as the word the model has to impute a vector for.
    known_vec = model['night']
    imputed_vec = model['wordnotpresentinvocabulary']
    known_dtype, imputed_dtype = known_vec.dtype, imputed_vec.dtype
    self.assertEqual(known_dtype, imputed_dtype)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit b648db7

Please sign in to comment.