Loading fastText models using only bin file + travis hotfix (#1341)
* french wiki issue resolved

* bin and vec mismatch handled

* added test for bin-only loading

* [WIP] loading bin only

* word vec from its ngrams

* [WIP] word vec from ngrams

* [WIP] getting syn0 from all n-grams

* [TDD] test comparing word vector from bin_only and default loading

* cleaned up test code

* added docstring for bin_only

* resolved wiki.fr issue

* pep8 fixes

* default bin file loading only

* logging info modified plus changes according to review

* removed unused code in fasttext.py

* removed unused code and vec files from test

* added lee_fasttext vec files again

* re-added removed files and unused code

* added file name in logging info

* removing unused load_word2vec_format code

* updated logging info and comments

* input file name accepted with or without .bin extension

* fixed typo

* test for file name

* minor change to input filename handling in ft wrapper

* changes to logging and assert messages, pep8 fixes

* removes redundant .vec files

* fixes utf8 bug in flake8_diff.sh script
prakhar2b authored and menshikh-iv committed Jun 28, 2017
1 parent 7ccaabc commit e1d89c2
Showing 6 changed files with 66 additions and 2,132 deletions.
5 changes: 3 additions & 2 deletions continuous_integration/travis/flake8_diff.sh
@@ -115,9 +115,10 @@ echo -e '\nRunning flake8 on the diff in the range' "$COMMIT_RANGE" \
echo '--------------------------------------------------------------------------------'

# We ignore files from sklearn/externals.
+# Excluding vec files since they contain non-utf8 content and flake8 raises an exception for non-utf8 input
# We need the following command to exit with 0 hence the echo in case
# there is no match
MODIFIED_FILES="$(git diff --name-only $COMMIT_RANGE || echo "no_match")"
MODIFIED_FILES="$(git diff --name-only $COMMIT_RANGE -- . ':(exclude)*.vec' || echo "no_match")"

check_files() {
files="$1"
@@ -133,6 +134,6 @@ check_files() {
if [[ "$MODIFIED_FILES" == "no_match" ]]; then
echo "No file has been modified"
else
check_files "$(echo "$MODIFIED_FILES" )" "--ignore=E501,E731,E12,W503 --exclude=*.sh,*.md,*.yml,*.rst,*.ipynb,*.txt,*.csv,*.vec,Dockerfile*"
check_files "$(echo "$MODIFIED_FILES" )" "--ignore=E501,E731,E12,W503 --exclude=*.sh,*.md,*.yml,*.rst,*.ipynb,*.txt,*.csv,Dockerfile*"
fi
echo -e "No problem detected by flake8\n"
75 changes: 56 additions & 19 deletions gensim/models/wrappers/fasttext.py
@@ -35,7 +35,7 @@
import numpy as np
from numpy import float32 as REAL, sqrt, newaxis
from gensim import utils
-from gensim.models.keyedvectors import KeyedVectors
+from gensim.models.keyedvectors import KeyedVectors, Vocab
from gensim.models.word2vec import Word2Vec

from six import string_types
@@ -222,10 +222,6 @@ def save(self, *args, **kwargs):
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_all_norm'])
super(FastText, self).save(*args, **kwargs)

-@classmethod
-def load_word2vec_format(cls, *args, **kwargs):
-    return FastTextKeyedVectors.load_word2vec_format(*args, **kwargs)

@classmethod
def load_fasttext_format(cls, model_file, encoding='utf8'):
"""
@@ -235,13 +231,17 @@ def load_fasttext_format(cls, model_file, encoding='utf8'):
with a model loaded this way, though you can query for word similarity etc.
`model_file` is the path to the FastText output files.
-FastText outputs two training files - `/path/to/train.vec` and `/path/to/train.bin`
-Expected value for this example: `/path/to/train`
+FastText outputs two model files - `/path/to/model.vec` and `/path/to/model.bin`
+Expected value for this example: `/path/to/model` or `/path/to/model.bin`,
+as gensim requires only the `.bin` file to load the entire fastText model.
"""
model = cls()
-model.wv = cls.load_word2vec_format('%s.vec' % model_file, encoding=encoding)
-model.load_binary_data('%s.bin' % model_file, encoding=encoding)
+if not model_file.endswith('.bin'):
+    model_file += '.bin'
+model.file_name = model_file
+model.load_binary_data(encoding=encoding)
return model
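
For illustration, a minimal usage sketch of the changed entry point (paths and the query word are hypothetical; only the .bin file needs to exist on disk):

    from gensim.models.wrappers import FastText

    # Both calls load the same full model; the '.bin' suffix is appended when missing.
    model = FastText.load_fasttext_format('/path/to/model')
    model = FastText.load_fasttext_format('/path/to/model.bin')

    print(model.wv['night'])  # in-vocabulary lookup works as before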

@classmethod
@@ -254,9 +254,9 @@ def delete_training_files(cls, model_file):
logger.debug('Training files %s not found when attempting to delete', model_file)
pass

-def load_binary_data(self, model_binary_file, encoding='utf8'):
+def load_binary_data(self, encoding='utf8'):
"""Loads data from the output binary file created by FastText training"""
-with utils.smart_open(model_binary_file, 'rb') as f:
+with utils.smart_open(self.file_name, 'rb') as f:
self.load_model_params(f)
self.load_dict(f, encoding=encoding)
self.load_vectors(f)
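
For orientation, load_model_params (collapsed in this view) reads a fixed-size header that precedes the dictionary. A simplified sketch of that layout, assuming fastText's args.cc field order and magic constant; this is not the wrapper's exact code:

    import struct

    FASTTEXT_FORMAT_MAGIC = 793712314  # assumed int32 marker of the newer bin format

    def read_header(f):
        first, second = struct.unpack('@2i', f.read(8))
        if first == FASTTEXT_FORMAT_MAGIC:      # newer format: magic + version, then args
            dim, ws, epoch, min_count, neg, _, loss, model, bucket, minn, maxn, _ = \
                struct.unpack('@12i', f.read(48))
        else:                                   # older format: args start immediately
            dim, ws = first, second
            epoch, min_count, neg, _, loss, model, bucket, minn, maxn, _ = \
                struct.unpack('@10i', f.read(40))
        t, = struct.unpack('@d', f.read(8))     # sub-sampling threshold
        return dict(dim=dim, bucket=bucket, minn=minn, maxn=maxn, t=t)
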
@@ -287,12 +287,12 @@ def load_model_params(self, file_handle):
def load_dict(self, file_handle, encoding='utf8'):
vocab_size, nwords, _ = self.struct_unpack(file_handle, '@3i')
# Vocab stored by [Dictionary::save](https://github.com/facebookresearch/fastText/blob/master/src/dictionary.cc)
-assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'
-assert len(self.wv.vocab) == vocab_size, 'mismatch between vocab sizes'
+logger.info("loading %s words for fastText model from %s", vocab_size, self.file_name)

self.struct_unpack(file_handle, '@1q') # number of tokens
if self.new_format:
pruneidx_size, = self.struct_unpack(file_handle, '@q')
-for i in range(nwords):
+for i in range(vocab_size):
word_bytes = b''
char_byte = file_handle.read(1)
# Read vocab word
@@ -301,8 +301,26 @@ def load_dict(self, file_handle, encoding='utf8'):
char_byte = file_handle.read(1)
word = word_bytes.decode(encoding)
count, _ = self.struct_unpack(file_handle, '@qb')
-assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
-self.wv.vocab[word].count = count
+
+if i == nwords and i < vocab_size:
+    # To handle the error in the pretrained wiki.fr (French) vectors.
+    # For more info: https://github.com/facebookresearch/fastText/issues/218
+    assert word == "__label__", (
+        'mismatched vocab_size ({}) and nwords ({}), extra word "{}"'.format(vocab_size, nwords, word))
+    continue  # don't add word to vocab
+
+self.wv.vocab[word] = Vocab(index=i, count=count)
+self.wv.index2word.append(word)
+
+assert len(self.wv.vocab) == nwords, (
+    'mismatch between final vocab size ({} words) '
+    'and expected number of words ({} words)'.format(len(self.wv.vocab), nwords))
+if len(self.wv.vocab) != vocab_size:
+    # this warning is expected only for the pretrained French vectors, wiki.fr
+    logger.warning(
+        "mismatch between final vocab size (%s words) and expected vocab size (%s words)",
+        len(self.wv.vocab), vocab_size)

if self.new_format:
for j in range(pruneidx_size):
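
A self-contained toy run of the same guard (tokens and counts made up; see fastText issue #218 for the real wiki.fr case):

    entries = [('le', 560), ('la', 540), ('__label__', 0)]  # stray trailing token
    vocab_size, nwords = len(entries), len(entries) - 1     # wiki.fr-style mismatch
    vocab = {}
    for i, (word, count) in enumerate(entries):
        if i == nwords and i < vocab_size:
            assert word == '__label__'
            continue                    # skip the extra entry, as the wrapper does
        vocab[word] = (i, count)
    assert len(vocab) == nwords         # final vocab matches nwords, not vocab_size
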
@@ -313,7 +331,8 @@ def load_vectors(self, file_handle):
self.struct_unpack(file_handle, '@?') # bool quant_input in fasttext.cc
num_vectors, dim = self.struct_unpack(file_handle, '@2q')
# Vectors stored by [Matrix::save](https://github.com/facebookresearch/fastText/blob/master/src/matrix.cc)
-assert self.vector_size == dim, 'mismatch between model sizes'
+assert self.vector_size == dim, (
+    'mismatch between vector size in model params ({}) and model vectors ({})'.format(self.vector_size, dim))
float_size = struct.calcsize('@f')
if float_size == 4:
dtype = np.dtype(np.float32)
@@ -324,7 +343,9 @@
self.wv.syn0_all = np.fromfile(file_handle, dtype=dtype, count=num_vectors * dim)
self.wv.syn0_all = self.wv.syn0_all.reshape((num_vectors, dim))
assert self.wv.syn0_all.shape == (self.bucket + len(self.wv.vocab), self.vector_size), \
-    'mismatch between weight matrix shape and vocab/model size'
+    'mismatch between actual weight matrix shape {} and expected shape {}'.format(
+        self.wv.syn0_all.shape, (self.bucket + len(self.wv.vocab), self.vector_size))

self.init_ngrams()
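
The shape assertion reflects fastText's storage layout, where the first len(vocab) rows of the input matrix are whole-word vectors and the remaining `bucket` rows hold hashed n-gram vectors. A toy sketch with hypothetical sizes:

    import numpy as np

    nwords, bucket, dim = 3, 8, 4                # hypothetical sizes
    syn0_all = np.zeros((nwords + bucket, dim))  # word rows first, then n-gram buckets
    word_rows = syn0_all[:nwords]                # one row per vocabulary word
    ngram_rows = syn0_all[nwords:]               # shared, hash-addressed n-gram rows
    assert syn0_all.shape == (bucket + nwords, dim)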

def struct_unpack(self, file_handle, fmt):
@@ -340,8 +361,12 @@ def init_ngrams(self):
"""
self.wv.ngrams = {}
all_ngrams = []
-for w, v in self.wv.vocab.items():
+self.wv.syn0 = np.zeros((len(self.wv.vocab), self.vector_size), dtype=REAL)
+
+for w, vocab in self.wv.vocab.items():
    all_ngrams += self.compute_ngrams(w, self.wv.min_n, self.wv.max_n)
+    self.wv.syn0[vocab.index] += np.array(self.wv.syn0_all[vocab.index])
+
all_ngrams = set(all_ngrams)
self.num_ngram_vectors = len(all_ngrams)
ngram_indices = []
Expand All @@ -351,6 +376,18 @@ def init_ngrams(self):
self.wv.ngrams[ngram] = i
self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)

+ngram_weights = self.wv.syn0_all
+
+logger.info("loading weights for %s words for fastText model from %s", len(self.wv.vocab), self.file_name)
+
+for w, vocab in self.wv.vocab.items():
+    word_ngrams = self.compute_ngrams(w, self.wv.min_n, self.wv.max_n)
+    for word_ngram in word_ngrams:
+        self.wv.syn0[vocab.index] += np.array(ngram_weights[self.wv.ngrams[word_ngram]])
+
+    self.wv.syn0[vocab.index] /= (len(word_ngrams) + 1)
+logger.info("loaded %s weight matrix for fastText model from %s", self.wv.syn0.shape, self.file_name)

@staticmethod
def compute_ngrams(word, min_n, max_n):
ngram_indices = []
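
The body of compute_ngrams is collapsed in this view; as a reference, a sketch of the n-gram extraction it performs, assuming the usual fastText convention of bracketing the word with '<' and '>' (a reimplementation for illustration, not the diff's exact code):

    def compute_ngrams_sketch(word, min_n, max_n):
        extended = '<' + word + '>'           # add word-boundary markers
        ngrams = []
        for n in range(min_n, max_n + 1):     # every n-gram length in [min_n, max_n]
            for i in range(len(extended) - n + 1):
                ngrams.append(extended[i:i + n])
        return ngrams

    # compute_ngrams_sketch('night', 3, 4) ->
    # ['<ni', 'nig', 'igh', 'ght', 'ht>', '<nig', 'nigh', 'ight', 'ght>']
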
172 changes: 0 additions & 172 deletions gensim/test/test_data/cp852_fasttext.vec

This file was deleted.
