Skip to content

Commit

Permalink
KeyedVectors & *2Vec API streamlining, consistency (#2698)
Browse files Browse the repository at this point in the history
* slim low-value warnings

* clarify vectors/vectors_vocab relationship; fix lockf & nonsense ngram-norming confusion

* mv FT, KV tests to right place

* rm deprecations, obsolete refs/tests, delete_temporary_training_data, update usages

* update usages, tests, flake8 cleanup

* expand KeyedVectors to obviate Doc2VecKeyedVectors; upconvert old offset-style doctags

* fix docstring warnings; update usages

* rm unused old plain-python codepaths

* unify class comments under __init__ for consistncy w/ api doc presentation

* name/comment harmonization (rm 'entity', lessen 'word'-centricity)

* table formatting

* return pyemd to linux test env

* split backcompat tests for better resolution

* convert Vocab & related data items to use dataclasses

* rm obsolete Vocab/Trainable/abstract/Wrapper classes, persistent callbacks (bug #2136), outdated tests/warnings; update usages

* tune tests for stability, runtimes; rm auto reruns that hide flakiness

* fix numpy FutureWarning: arrays to stack must be sequence

* (commented-out) deoptimization option

* stronger FB model testing; no _unpack_copy test

* merge redundant methods; rm duplicated imports/defs

* rationalize _lockf, buckets_word behaviors

* rename .docvecs to .dv

* update usages; rm obsolete tests; restore gensim.utils import

* intensify FT tests (more epochs, more buckets)

* flake8-3.8.0 style fixes - but also pin flake8-3.7.9 vs 3.8.0 'output_file' error

* replace vectors_norm with 1d norms

* tighten testParallel

* rm .vocab & 'Vocab' classes; add expandable 'vecattrs'

* update usages (no vocabs)

* enable running inside '-m mtprof' (or cProfile) via explicit unittest.main(module=..)

* faster sample_int reads

* load_word2vec_format(.., no_header=True) to support GLoVe text vectors

* refactor & comment lockf feature; allow single-element lockf

* improve FT comment

* rm deprecated/unneded init_sims calls

* fixes to code style

* flake8: fix overlong lines

* rm stray merge error

* rm duplicated , old nonstandard hash workarounds

* use numpy-recommended PRNG constructor

* add sg to FastTextConfig & consult it; rm remaining broken-hash cruft

* reorg conditional packages for clarity

* comments, names, refactoring, randomization

* Apply suggestions from code review

Co-authored-by: Radim Řehůřek <me@radimrehurek.com>

* fix cruft left from suggestion

* fix numpy-32bit-on-Windows; executable docs

* mv lee_corpus to utils; cleanup

* update poincare for latest KV __init__ signature

* restore word_vec method for proper overriding, but rm usages

* Apply suggestions from code review

Co-authored-by: Radim Řehůřek <me@radimrehurek.com>

* adjust testParallel against failure risk

* intensify training for an occasionally failing test

* clarify word/char ngrams handling; rm outdated comments

* mostly avoid duplciating FastTextConfig fields into locals

* avoid copies/pointers for no-bucket (FT as W2V) case

* rm obsolete test (already skipped & somewhat originally misguided)

* simpler/faster .get(..., default) (avoids exception-catching in has_index_for)

* add default option to get_index; avoid exception in has_index_for

* chained range check

Co-authored-by: Radim Řehůřek <me@radimrehurek.com>

* Update CHANGELOG.md

Co-authored-by: Radim Řehůřek <radimrehurek@seznam.cz>
Co-authored-by: Radim Řehůřek <me@radimrehurek.com>
Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
4 people authored Jul 19, 2020
1 parent 4cdf228 commit c0e0169
Show file tree
Hide file tree
Showing 76 changed files with 5,642 additions and 14,143 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ Changes

## Unreleased

This release contains a major refactoring.

### :+1: Improvements

* KeyedVectors & X2Vec API streamlining, consistency (PR [#2698](https://github.com/RaRe-Technologies/gensim/pull/2698), __[@gojomo](https://github.com/gojomo)__)
* No more wheels for x32 platforms (if you need x32 binaries, please build them yourself).
(__[menshikh-iv](https://github.com/menshikh-iv)__, [#6](https://github.com/RaRe-Technologies/gensim-wheels/pull/6))
* Speed up random number generation in word2vec model (PR [#2864](https://github.com/RaRe-Technologies/gensim/pull/2864), __[@zygm0nt](https://github.com/zygm0nt)__)
Expand Down
2 changes: 0 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ include gensim/models/fasttext_inner.pxd
include gensim/models/fasttext_corpusfile.cpp
include gensim/models/fasttext_corpusfile.pyx

include gensim/models/_utils_any2vec.c
include gensim/models/_utils_any2vec.pyx
include gensim/corpora/_mmreader.c
include gensim/corpora/_mmreader.pyx
include gensim/_matutils.c
Expand Down
9 changes: 0 additions & 9 deletions docs/src/apiref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ Modules:
models/coherencemodel
models/basemodel
models/callbacks
models/utils_any2vec
models/_utils_any2vec
models/word2vec_inner
models/doc2vec_inner
models/fasttext_inner
Expand All @@ -63,13 +61,6 @@ Modules:
models/wrappers/ldavowpalwabbit.rst
models/wrappers/wordrank
models/wrappers/varembed
models/wrappers/fasttext
models/deprecated/doc2vec
models/deprecated/fasttext
models/deprecated/word2vec
models/deprecated/keyedvectors
models/deprecated/fasttext_wrapper
models/base_any2vec
similarities/docsim
similarities/termsim
similarities/index
Expand Down
2 changes: 1 addition & 1 deletion docs/src/auto_examples/tutorials/run_fasttext.rst
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ The example training corpus is a toy corpus, results are not expected to be good
.. code-block:: none
/Volumes/work/workspace/gensim_misha/gensim/models/keyedvectors.py:877: FutureWarning: arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.
vectors = vstack(self.word_vec(word, use_norm=True) for word in used_words).astype(REAL)
vectors = vstack(self.get_vector(word, use_norm=True) for word in used_words).astype(REAL)
'breakfast'
Expand Down
2 changes: 1 addition & 1 deletion docs/src/auto_examples/tutorials/run_word2vec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ Which of the below does not belong in the sequence?
.. code-block:: none
/home/misha/git/gensim/gensim/models/keyedvectors.py:877: FutureWarning: arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.
vectors = vstack(self.word_vec(word, use_norm=True) for word in used_words).astype(REAL)
vectors = vstack(self.get_vector(word, use_norm=True) for word in used_words).astype(REAL)
car
Expand Down
9 changes: 0 additions & 9 deletions docs/src/models/_utils_any2vec.rst

This file was deleted.

10 changes: 0 additions & 10 deletions docs/src/models/base_any2vec.rst

This file was deleted.

9 changes: 0 additions & 9 deletions docs/src/models/deprecated/doc2vec.rst

This file was deleted.

10 changes: 0 additions & 10 deletions docs/src/models/deprecated/fasttext.rst

This file was deleted.

10 changes: 0 additions & 10 deletions docs/src/models/deprecated/fasttext_wrapper.rst

This file was deleted.

9 changes: 0 additions & 9 deletions docs/src/models/deprecated/keyedvectors.rst

This file was deleted.

9 changes: 0 additions & 9 deletions docs/src/models/deprecated/word2vec.rst

This file was deleted.

9 changes: 0 additions & 9 deletions docs/src/models/utils_any2vec.rst

This file was deleted.

9 changes: 0 additions & 9 deletions docs/src/models/wrappers/fasttext.rst

This file was deleted.

82 changes: 41 additions & 41 deletions gensim/corpora/sharded_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import logging
import os
import math
import numpy
import scipy.sparse as sparse
import time

import numpy
import scipy.sparse as sparse
from six.moves import range

import gensim
Expand Down Expand Up @@ -263,9 +263,7 @@ def init_shards(self, output_prefix, corpus, shardsize=4096, dtype=_default_dtyp

is_corpus, corpus = gensim.utils.is_corpus(corpus)
if not is_corpus:
raise ValueError(
"Cannot initialize shards without a corpus to read from! (Got corpus type: {0})".format(type(corpus))
)
raise ValueError("Cannot initialize shards without a corpus to read from! Corpus type: %s" % type(corpus))

proposed_dim = self._guess_n_features(corpus)
if proposed_dim != self.dim:
Expand Down Expand Up @@ -360,7 +358,7 @@ def load_shard(self, n):

filename = self._shard_name(n)
if not os.path.isfile(filename):
raise ValueError('Attempting to load nonexistent shard no. {0}'.format(n))
raise ValueError('Attempting to load nonexistent shard no. %s' % n)
shard = gensim.utils.unpickle(filename)

self.current_shard = shard
Expand All @@ -387,11 +385,9 @@ def shard_by_offset(self, offset):
"""
k = int(offset / self.shardsize)
if offset >= self.n_docs:
raise ValueError('Too high offset specified ({0}), available '
'docs: {1}'.format(offset, self.n_docs))
raise ValueError('Too high offset specified (%s), available docs: %s' % (offset, self.n_docs))
if offset < 0:
raise ValueError('Negative offset {0} currently not'
' supported.'.format(offset))
raise ValueError('Negative offset %s currently not supported.' % offset)
return k

def in_current(self, offset):
Expand All @@ -411,7 +407,7 @@ def in_next(self, offset):
"""
if self.current_shard_n == self.n_shards:
return False # There's no next shard.
return (self.offsets[self.current_shard_n + 1] <= offset) and (offset < self.offsets[self.current_shard_n + 2])
return self.offsets[self.current_shard_n + 1] <= offset and offset < self.offsets[self.current_shard_n + 2]

def resize_shards(self, shardsize):
"""
Expand Down Expand Up @@ -440,9 +436,8 @@ def resize_shards(self, shardsize):
if new_stop > self.n_docs:
# Sanity check
assert new_shard_idx == n_new_shards - 1, \
'Shard no. {0} that ends at {1} over last document' \
' ({2}) is not the last projected shard ({3})???' \
''.format(new_shard_idx, new_stop, self.n_docs, n_new_shards)
'Shard no. %r that ends at %r over last document (%r) is not the last projected shard (%r)' % (
new_shard_idx, new_stop, self.n_docs, n_new_shards)
new_stop = self.n_docs

new_shard = self[new_start:new_stop]
Expand All @@ -466,9 +461,9 @@ def resize_shards(self, shardsize):
for old_shard_n, old_shard_name in enumerate(old_shard_names):
os.remove(old_shard_name)
except Exception as e:
logger.error(
'Exception occurred during old shard no. %d removal: %s.\nAttempting to at least move new shards in.',
old_shard_n, str(e)
logger.exception(
'Error during old shard no. %d removal: %s.\nAttempting to at least move new shards in.',
old_shard_n, str(e),
)
finally:
# If something happens with cleaning up - try to at least get the
Expand All @@ -479,7 +474,7 @@ def resize_shards(self, shardsize):
# If something happens when we're in this stage, we're screwed.
except Exception as e:
logger.exception(e)
raise RuntimeError('Resizing completely failed for some reason. Sorry, dataset is probably ruined...')
raise RuntimeError('Resizing completely failed. Sorry, dataset is probably ruined...')
finally:
# Sets the new shard stats.
self.n_shards = n_new_shards
Expand Down Expand Up @@ -524,19 +519,18 @@ def _guess_n_features(self, corpus):
else:
if not self.dim:
raise TypeError(
"Couldn't find number of features, refusing to guess "
"(dimension set to {0}, type of corpus: {1})."
.format(self.dim, type(corpus))
"Couldn't find number of features, refusing to guess. Dimension: %s, corpus: %s)" % (
self.dim, type(corpus),
)
)
else:
logger.warning("Couldn't find number of features, trusting supplied dimension (%d)", self.dim)
n_features = self.dim
logger.warning("Couldn't find number of features, trusting supplied dimension (%d)", self.dim)
n_features = self.dim

if self.dim and n_features != self.dim:
logger.warning(
"Discovered inconsistent dataset dim (%d) and feature count from corpus (%d). "
"Coercing to dimension given by argument.",
self.dim, n_features
self.dim, n_features,
)

return n_features
Expand Down Expand Up @@ -591,7 +585,7 @@ def __getitem__(self, offset):
start = offset.start
stop = offset.stop
if stop > self.n_docs:
raise IndexError('Requested slice offset {0} out of range ({1} docs)'.format(stop, self.n_docs))
raise IndexError('Requested slice offset %s out of range (%s docs)' % (stop, self.n_docs))

# - get range of shards over which to iterate
first_shard = self.shard_by_offset(start)
Expand Down Expand Up @@ -674,21 +668,23 @@ def __getitem__(self, offset):

def __add_to_slice(self, s_result, result_start, result_stop, start, stop):
"""
Add the rows of the current shard from `start` to `stop`
Add rows of the current shard from `start` to `stop`
into rows `result_start` to `result_stop` of `s_result`.
Operation is based on the self.sparse_serialize setting. If the shard
Operation is based on the ``self.sparse_serialize`` setting. If the shard
contents are dense, then s_result is assumed to be an ndarray that
already supports row indices `result_start:result_stop`. If the shard
contents are sparse, assumes that s_result has `result_start` rows
and we should add them up to `result_stop`.
Returns the resulting s_result.
Return the resulting ``s_result``.
"""
if (result_stop - result_start) != (stop - start):
raise ValueError(
'Result start/stop range different than stop/start range (%d - %d vs. %d - %d)'
% (result_start, result_stop, start, stop)
'Result start/stop range different than stop/start range (%s - %s vs. %s - %s)' % (
result_start, result_stop, start, stop,
)
)

# Dense data: just copy using numpy's slice notation
Expand All @@ -699,16 +695,16 @@ def __add_to_slice(self, s_result, result_start, result_stop, start, stop):

# A bit more difficult, we're using a different structure to build the
# result.
else:
if s_result.shape != (result_start, self.dim):
raise ValueError(
'Assuption about sparse s_result shape invalid: {0} expected rows, {1} real rows.'
.format(result_start, s_result.shape[0])
if s_result.shape != (result_start, self.dim):
raise ValueError(
'Assuption about sparse s_result shape invalid: %s expected rows, %s real rows.' % (
result_start, s_result.shape[0],
)
)

tmp_matrix = self.current_shard[start:stop]
s_result = sparse.vstack([s_result, tmp_matrix])
return s_result
tmp_matrix = self.current_shard[start:stop]
s_result = sparse.vstack([s_result, tmp_matrix])
return s_result

def _getitem_format(self, s_result):
if self.sparse_serialization:
Expand Down Expand Up @@ -817,5 +813,9 @@ def serialize(serializer, fname, corpus, id2word=None, index_fname=None, progres
Ignore the parameters id2word, index_fname, progress_cnt, labels
and metadata. They currently do nothing and are here only to
provide a compatible method signature with superclass."""
serializer.save_corpus(fname, corpus, id2word=id2word, progress_cnt=progress_cnt, metadata=metadata, **kwargs)
provide a compatible method signature with superclass.
"""
serializer.save_corpus(
fname, corpus, id2word=id2word, progress_cnt=progress_cnt, metadata=metadata, **kwargs,
)
3 changes: 1 addition & 2 deletions gensim/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .logentropy_model import LogEntropyModel # noqa:F401
from .word2vec import Word2Vec # noqa:F401
from .doc2vec import Doc2Vec # noqa:F401
from .keyedvectors import KeyedVectors, WordEmbeddingSimilarityIndex # noqa:F401
from .keyedvectors import KeyedVectors # noqa:F401
from .ldamulticore import LdaMulticore # noqa:F401
from .phrases import Phrases # noqa:F401
from .normmodel import NormModel # noqa:F401
Expand All @@ -23,7 +23,6 @@
from .translation_matrix import TranslationMatrix, BackMappingTranslationMatrix # noqa:F401

from . import wrappers # noqa:F401
from . import deprecated # noqa:F401

from gensim import interfaces, utils

Expand Down
Loading

0 comments on commit c0e0169

Please sign in to comment.