Skip to content

Commit

Permalink
Fix hashing-trick from FastText.build_vocab. Fix #1765 (#1768)
Browse files Browse the repository at this point in the history
* adds test case to check wv.syn0_ngrams.shape with bucket

* corrects `ngram_hash` in `FastText.init_ngrams`
  • Loading branch information
manneshiva authored and menshikh-iv committed Dec 7, 2017
1 parent 4398a61 commit cf46f69
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def init_ngrams(self, update=False):
ngram_indices = []
new_hash_count = 0
for i, ngram in enumerate(all_ngrams):
ngram_hash = ft_hash(ngram)
ngram_hash = ft_hash(ngram) % self.bucket
if ngram_hash in self.wv.hash2index:
self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]
else:
Expand All @@ -160,7 +160,7 @@ def init_ngrams(self, update=False):
logger.info("Number of new ngrams is %d", len(new_ngrams))
new_hash_count = 0
for i, ngram in enumerate(new_ngrams):
ngram_hash = ft_hash(ngram)
ngram_hash = ft_hash(ngram) % self.bucket
if ngram_hash not in self.wv.hash2index:
self.wv.hash2index[ngram_hash] = new_hash_count + self.old_hash2index_len
self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]
Expand Down
7 changes: 7 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,13 @@ def test_persistence_word2vec_format(self):
self.assertRaises(DeprecationWarning, FT_gensim.load_word2vec_format, tmpf)
self.assertRaises(NotImplementedError, FastTextKeyedVectors.load_word2vec_format, tmpf)

def test_bucket_ngrams(self):
model = FT_gensim(size=10, min_count=1, bucket=20)
model.build_vocab(sentences)
self.assertEqual(model.wv.syn0_ngrams.shape, (20, 10))
model.build_vocab(new_sentences, update=True)
self.assertEqual(model.wv.syn0_ngrams.shape, (20, 10))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit cf46f69

Please sign in to comment.