
Any2Vec Refactoring Summary

Shiva Manne edited this page Feb 2, 2018 · 4 revisions

Refactoring Any2Vec

This document briefly describes the refactoring task, to help future contributors understand, extend and improve this work.

Objectives

We specifically target the "any2vec" algorithms (Word2Vec, Doc2Vec and FastText), with the following high-level objectives:

  1. Clean up the code to make it more modular and DRY.

  2. Make it easy to onboard/implement similar algorithms in the future (e.g. Poincare).

  3. Maintain backward compatibility.

  4. Provide a "Callback" mechanism that lets the user apply functions to the model at given stages during training.

  5. Improve the code design.
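Objective 4 can be illustrated with a minimal sketch in plain Python. It assumes a hypothetical Callback base class with hooks named on_train_begin, on_epoch_begin, on_epoch_end and on_train_end. It also assumes a stand-in ToyModel whose "training" only halves a fake loss. None of these names is taken from the final gensim API. The model calls every registered callback at fixed points in training, and each callback overrides only the hooks it needs.

```python
from abc import ABC


class Callback(ABC):
    """Hypothetical base class: subclasses override only the hooks they need."""

    def on_train_begin(self, model):
        pass

    def on_epoch_begin(self, model):
        pass

    def on_epoch_end(self, model):
        pass

    def on_train_end(self, model):
        pass


class EpochLossLogger(Callback):
    """Records the (fake) loss reported by the model after every epoch."""

    def __init__(self):
        self.losses = []

    def on_epoch_end(self, model):
        self.losses.append(model.running_loss)


class ToyModel:
    """Stand-in model: each 'epoch' just halves a fake loss value."""

    def __init__(self):
        self.running_loss = 1.0

    def train(self, epochs, callbacks=()):
        for cb in callbacks:
            cb.on_train_begin(self)
        for _ in range(epochs):
            for cb in callbacks:
                cb.on_epoch_begin(self)
            self.running_loss /= 2  # placeholder for one real training pass
            for cb in callbacks:
                cb.on_epoch_end(self)
        for cb in callbacks:
            cb.on_train_end(self)


logger = EpochLossLogger()
ToyModel().train(epochs=3, callbacks=[logger])
print(logger.losses)  # [0.5, 0.25, 0.125]
```

In this sketch, the training loop never needs to know what a callback does. Saving the model every epoch or logging the loss becomes a small user-defined class rather than a change to the training code.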

Old Design

Refer to the UML diagram below for a better view of the old design:

any2vec_old_design.png

Drawbacks/concerns about the existing code base:

  • Does not maintain a consistent structure

    • Each algorithm is implemented independently by subclassing another algorithm whenever it needs to reuse most (but not all) of the base algorithm's methods. Doc2Vec, the FastText wrapper and native FastText all subclass Word2Vec, so the child algorithms inherit methods that were never meant for them. For example, load_word2vec_format makes no sense when called on the native FastText model.

    • KeyedVectors were introduced and integrated for word vectors (Word2Vec, FastText) but not for document vectors (Doc2Vec.docvecs), which is inconsistent.

  • All of an algorithm's methods are implemented in one huge class, making it difficult to understand, maintain and develop.

  • Major common functionality is implemented in a single class/method, which is then hacked to support other algorithms. For instance, all the code for multithreaded training lives in Word2Vec.train. In addition, Word2Vec's train_sg_pair/train_cbow_pair are hacked to support native FastText. Adding more any2vec algorithms would therefore require even more such hacks in "other classes" (extra if-else branches etc.).

  • Attribute names such as syn0 and doctag_syn0 are unintuitive and might confuse a new or unfamiliar user.

  • The documentation has considerable room for improvement. Going through the current code can be time-consuming for a user focused on adding non-trivial features to existing algorithms.

  • There is no easy way to offer the user useful optional features, such as saving the model after every epoch or logging the loss after every batch. I believe loss logging is a highly requested feature. Loss computation is already implemented for Word2Vec, but the loss is not displayed to the user.

  • There is a lot of code repetition in the Cython parts of the implementation, though not as much in the Python parts.
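The first drawback, inheriting unintended methods, can be reproduced with two toy classes. OldWord2Vec and OldNativeFastText are hypothetical stand-ins, not the actual gensim classes. Subclassing only to reuse train() also exposes every other public method of the parent on the child.

```python
class OldWord2Vec:
    """Stand-in for the old Word2Vec, which owned every method, including I/O helpers."""

    @classmethod
    def load_word2vec_format(cls, path):
        # Stand-in: the real method parses the original C word2vec file format.
        raise NotImplementedError("format parsing is out of scope for this sketch")

    def train(self, sentences):
        return "word2vec training"


class OldNativeFastText(OldWord2Vec):
    """Subclasses only to reuse/override train(), but inherits everything else too."""

    def train(self, sentences):
        return "fasttext training"


# Public methods the child exposes without defining them itself.
inherited = [
    name for name in dir(OldNativeFastText)
    if not name.startswith("_") and name not in vars(OldNativeFastText)
]
print(inherited)  # ['load_word2vec_format']
```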

New Design

The new design splits each any2vec implementation into the following four classes:

  1. Vocabulary Class

Deals explicitly with vocabulary building, pruning, updating and storage. This covers attributes such as vocab and index2word, and methods such as scan_vocab, update_vocab, sort_vocab, add_null_word and prune_vocab.
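A minimal sketch of such a vocabulary component, in plain Python, follows. The method names come from the description above. The signatures, the min_count parameter and the frequency-then-alphabetical ordering are assumptions made for illustration, and add_null_word is omitted.

```python
from collections import Counter


class Vocabulary:
    """Toy vocabulary component; method names follow the text, signatures are assumed."""

    def __init__(self, min_count=1):
        self.min_count = min_count
        self.raw_counts = Counter()  # word -> frequency seen so far
        self.vocab = {}              # kept word -> frequency
        self.index2word = []         # row index -> word

    def scan_vocab(self, sentences):
        # Count every token in an iterable of tokenized sentences.
        for sentence in sentences:
            self.raw_counts.update(sentence)

    def prune_vocab(self):
        # Keep only words that occur at least min_count times.
        self.vocab = {w: c for w, c in self.raw_counts.items() if c >= self.min_count}

    def sort_vocab(self):
        # Most frequent words first; ties broken alphabetically for determinism.
        self.index2word = sorted(self.vocab, key=lambda w: (-self.vocab[w], w))

    def update_vocab(self, sentences):
        # Online update: existing words keep their index, new qualifying words are appended.
        self.scan_vocab(sentences)
        for word, count in sorted(self.raw_counts.items(), key=lambda wc: (-wc[1], wc[0])):
            if count >= self.min_count:
                if word not in self.vocab:
                    self.index2word.append(word)
                self.vocab[word] = count


v = Vocabulary(min_count=2)
v.scan_vocab([["the", "cat", "sat"], ["the", "dog", "sat"], ["a", "cat"]])
v.prune_vocab()
v.sort_vocab()
print(v.index2word)  # ['cat', 'sat', 'the']
v.update_vocab([["dog", "ran"]])
print(v.index2word)  # ['cat', 'sat', 'the', 'dog']
```

In this sketch, update_vocab appends new words instead of re-sorting. Rows of the weight matrices are tied to index positions, so existing indices have to stay stable when the vocabulary grows.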

  2. Trainables Class

Contains the trainable weights of the model (syn1, syn1neg and vectors, formerly syn0) along with the tables/data structures that support training. Methods in this class initialize the weights and add weights for additional vocabulary words: reset_weights and update_weights.
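The two responsibilities of this class can be sketched with plain Python lists. The method names come from the text. The constructor arguments and the initialization scheme are assumptions: small random input vectors and zeroed negative-sampling weights. syn1 (hierarchical softmax) is left out for brevity.

```python
import random


class Trainables:
    """Toy container for trainable weights; names follow the text, details are assumed."""

    def __init__(self, vector_size, seed=1):
        self.vector_size = vector_size
        self.rng = random.Random(seed)
        self.vectors = []   # input vectors (formerly syn0), one row per word
        self.syn1neg = []   # output weights used by negative sampling

    def _random_row(self):
        # Small random values in [-0.5, 0.5) / vector_size so words start distinguishable.
        return [(self.rng.random() - 0.5) / self.vector_size for _ in range(self.vector_size)]

    def reset_weights(self, vocab_size):
        # Fresh initialization for a newly built vocabulary.
        self.vectors = [self._random_row() for _ in range(vocab_size)]
        self.syn1neg = [[0.0] * self.vector_size for _ in range(vocab_size)]

    def update_weights(self, new_vocab_size):
        # Grow both matrices for words appended to the vocabulary; old rows stay untouched.
        extra = new_vocab_size - len(self.vectors)
        self.vectors.extend(self._random_row() for _ in range(extra))
        self.syn1neg.extend([0.0] * self.vector_size for _ in range(extra))


t = Trainables(vector_size=4)
t.reset_weights(vocab_size=3)
first_row = list(t.vectors[0])
t.update_weights(new_vocab_size=5)
print(len(t.vectors), len(t.syn1neg))  # 5 5
print(t.vectors[0] == first_row)       # True
```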

  3. KeyedVectors Class

Stores the keyed vectors after training, along with other data required by methods that operate on the trained vectors.

Attributes: vocab, vectors (syn0) and more

Methods: similarity, most_similar, __getitem__ and similar methods
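A minimal sketch of a keyed-vectors store with the methods listed above follows. Here vocab simply maps a word to its row index, which is a simplification. The cosine-based scoring and the topn parameter are assumptions made for illustration.

```python
import math


class KeyedVectors:
    """Toy read-only store of trained vectors; method names follow the text."""

    def __init__(self, index2word, vectors):
        self.index2word = list(index2word)
        self.vocab = {w: i for i, w in enumerate(self.index2word)}  # word -> row index
        self.vectors = vectors

    def __getitem__(self, word):
        return self.vectors[self.vocab[word]]

    @staticmethod
    def _cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    def similarity(self, w1, w2):
        return self._cosine(self[w1], self[w2])

    def most_similar(self, word, topn=3):
        # Rank every other word by cosine similarity to `word`.
        scores = [(other, self.similarity(word, other)) for other in self.index2word if other != word]
        return sorted(scores, key=lambda pair: pair[1], reverse=True)[:topn]


kv = KeyedVectors(["king", "queen", "apple"], [[1.0, 0.9], [0.9, 1.0], [-1.0, 0.2]])
print(kv["queen"])                              # [0.9, 1.0]
print(round(kv.similarity("king", "queen"), 3))  # 0.994
print(kv.most_similar("king", topn=1)[0][0])    # queen
```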

  4. Model Class

This is the main class of the algorithm implementation. It initializes and orchestrates the building and training of a model using the three classes above.
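The orchestration role of the model class can be shown with three deliberately tiny components. All class names here are hypothetical. The "training" step is a placeholder that only increments a counter; a real model would run SGD over context windows. The model owns one instance of each component and wires them together: build the vocabulary, size the weights from it, and expose the same weights through the keyed vectors.

```python
from collections import Counter


class Vocab:
    def __init__(self, min_count):
        self.min_count = min_count
        self.index2word = []

    def build(self, sentences):
        counts = Counter(w for s in sentences for w in s)
        kept = [w for w, c in counts.items() if c >= self.min_count]
        self.index2word = sorted(kept, key=lambda w: (-counts[w], w))


class Trainables:
    def __init__(self, vector_size):
        self.vector_size = vector_size
        self.vectors = []

    def reset_weights(self, vocab_size):
        self.vectors = [[0.0] * self.vector_size for _ in range(vocab_size)]


class KeyedVectors:
    def __init__(self):
        self.vocab = {}
        self.vectors = []

    def __getitem__(self, word):
        return self.vectors[self.vocab[word]]


class ToyAny2Vec:
    """Model class: owns the three components and orchestrates them."""

    def __init__(self, sentences, vector_size=3, min_count=1):
        self.vocabulary = Vocab(min_count)
        self.trainables = Trainables(vector_size)
        self.wv = KeyedVectors()
        self.build_vocab(sentences)
        self.train(sentences)

    def build_vocab(self, sentences):
        self.vocabulary.build(sentences)
        self.trainables.reset_weights(len(self.vocabulary.index2word))
        # Share the same list object so the keyed vectors see trained values.
        self.wv.vocab = {w: i for i, w in enumerate(self.vocabulary.index2word)}
        self.wv.vectors = self.trainables.vectors

    def train(self, sentences):
        # Placeholder update: a real model would run SGD over context windows here.
        for sentence in sentences:
            for word in sentence:
                if word in self.wv.vocab:
                    self.trainables.vectors[self.wv.vocab[word]][0] += 1.0


model = ToyAny2Vec([["a", "b"], ["a", "c"]], min_count=1)
print(model.vocabulary.index2word)  # ['a', 'b', 'c']
print(model.wv["a"])                # [2.0, 0.0, 0.0]
```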

More specifically, the new code contains the following base classes:

  • BaseAny2Vec -- provides the implementation for multithreaded training and acts as an interface for future any2vec algorithms.

  • BaseKeyedVectors -- Abstract base class for keyed vectors.

  • BaseVocabBuilder -- Abstract base class for vocabulary building.

  • BaseModelTrainables -- Abstract base class for the Trainables class described above (2).

  • Callback -- Abstract base class used to build new callbacks.

  • BaseWordEmbeddingsModel -- Base class implementing common methods that are specifically useful for word embedding models.

  • WordEmbeddingsKeyedVectors -- Base class which implements common keyed vectors methods useful for word embeddings (subclassed by FastTextKeyedVectors, Word2VecKeyedVectors).

The algorithms inherit from these base classes to provide the concrete implementation classes: Word2Vec, Word2VecVocab, Word2VecTrainables, Word2VecKeyedVectors, FastText, FastTextVocab, FastTextTrainables, FastTextKeyedVectors, Doc2Vec, Doc2VecVocab, Doc2VecTrainables, Doc2VecKeyedVectors and PoincareKeyedVectors.
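The role of the abstract base classes can be sketched with Python's abc module. A shared intermediate class implements common logic once, and a concrete class that forgets a required method cannot be instantiated. Some class names below are borrowed from the list above, but their bodies, method signatures and the dot-product similarity are simplified assumptions, not the real gensim code. IncompleteVocab is purely hypothetical.

```python
from abc import ABC, abstractmethod


class BaseVocabBuilder(ABC):
    @abstractmethod
    def scan_vocab(self, sentences):
        """Collect raw word counts from the corpus."""


class BaseKeyedVectors(ABC):
    @abstractmethod
    def similarity(self, entity1, entity2):
        """Similarity score between two stored entities."""


class WordEmbeddingsKeyedVectors(BaseKeyedVectors):
    """Logic shared by all word-embedding keyed vectors lives here once."""

    def __init__(self, vectors):
        self.vectors = vectors  # word -> vector (simplified storage)

    def similarity(self, entity1, entity2):
        # Plain dot product keeps the sketch short; real code would normalize (cosine).
        return sum(x * y for x, y in zip(self.vectors[entity1], self.vectors[entity2]))


class Word2VecKeyedVectors(WordEmbeddingsKeyedVectors):
    pass  # inherits similarity() unchanged


class FastTextKeyedVectors(WordEmbeddingsKeyedVectors):
    pass  # a real version would add character n-gram handling


class IncompleteVocab(BaseVocabBuilder):
    pass  # forgot to implement scan_vocab


w2v = Word2VecKeyedVectors({"x": [1.0, 2.0], "y": [3.0, 4.0]})
ft = FastTextKeyedVectors({"x": [1.0, 0.0], "y": [0.5, 2.0]})
print(w2v.similarity("x", "y"), ft.similarity("x", "y"))  # 11.0 0.5

try:
    IncompleteVocab()
except TypeError:
    print("IncompleteVocab rejected: abstract method not implemented")
```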

The following UML diagram summarizes the new design and gives a better picture of it: any2vec_new_design.png

Key Details

  • New design follows a modular and consistent structure.

  • The BaseAny2Vec class implements multithreaded training: jobs/data are pushed onto a queue, and worker threads pick up these jobs for training. Hence, a new algorithm gets multi-worker training for free, without having to re-implement this part.

  • The new code would have negligible repetition.

  • Uncomplicated design and clear segregation of responsibilities among classes.

  • Better documentation will speed up future development.

  • The callback mechanism gives the user more flexibility to plug in individual features (e.g. computing accuracy on a downstream task with the trained vectors at the end of each epoch).

  • Maintains backward compatibility.
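The queue-based multithreading described in the Key Details above can be sketched with the standard library. train_multithreaded is a hypothetical helper, not BaseAny2Vec's actual method. The algorithm-specific part is passed in as train_batch, so the producer/worker machinery is written once and reused. Python threads only speed up real training when train_batch releases the GIL, as compiled (e.g. Cython) routines can.

```python
import queue
import threading


def train_multithreaded(corpus, train_batch, workers=3, batch_size=2):
    """Hypothetical helper: one producer fills a queue, `workers` threads consume it."""
    jobs = queue.Queue(maxsize=2 * workers)  # bounded so the producer cannot run far ahead
    results = []
    lock = threading.Lock()

    def worker():
        while True:
            job = jobs.get()
            if job is None:  # sentinel: no more jobs for this worker
                break
            processed = train_batch(job)  # algorithm-specific work is plugged in here
            with lock:
                results.append(processed)

    threads = [threading.Thread(target=worker) for _ in range(workers)]
    for t in threads:
        t.start()

    # Producer: chop the corpus into batches and enqueue them.
    for start in range(0, len(corpus), batch_size):
        jobs.put(corpus[start:start + batch_size])
    for _ in threads:
        jobs.put(None)  # one sentinel per worker

    for t in threads:
        t.join()
    return sum(results)


corpus = [["a", "b"], ["c"], ["d", "e", "f"], ["g"], ["h", "i"]]
words_trained = train_multithreaded(corpus, train_batch=lambda batch: sum(len(s) for s in batch))
print(words_trained)  # 9
```

Here the toy train_batch only counts words. A new algorithm would supply its own batch-training function and reuse the rest unchanged.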
