microTransformer

[Welcome illustration]

Index

  1. Project Overview
  2. Installation
  3. Data Preparation
  4. Training
  5. Sampling
  6. Credits

Project Overview

A simple implementation for training and evaluating a Transformer model. The model can be trained on a single consumer GPU in a few hours. The repository is structured as follows:

├── config.py   <- Transformer configuration and vocabulary
│
├── model.py    <- Transformer model, sampling and evaluation (<300 lines)
│
├── train.py    <- training the model (~150 lines)
│
├── sample.py   <- sampling from a trained model
│
├── get_data.py <- generates the validation and test data as CSV files
│
├── data.py     <- loading the data and creating batches
│
└── utils.py    <- sentence generation, tokenization and detokenization

microTransformer is based on the original Transformer architecture, with both the encoder and the decoder. The task we will be working on is sorting a sequence of characters: for example, the sequence "ABCB" is sorted as "ABBC". The vocabulary and the sentence length can be changed in the config.py file, and you can calculate the number of possible sequences with calculate_total_possibilities from utils.py.
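
As a sanity check, the count is simply vocab_size ** sentence_length, since each position can hold any character of the vocabulary. A minimal sketch, assuming (hypothetical) names VOCAB and SENTENCE_LEN for the values set in config.py and a fixed sentence length:

```python
# A minimal sketch, not the repo's code: VOCAB and SENTENCE_LEN are
# assumed names for the values configured in config.py, and sentences
# are assumed to have a fixed length.
VOCAB = "ABC"
SENTENCE_LEN = 4

# Each of the SENTENCE_LEN positions can hold any of len(VOCAB) characters.
total = len(VOCAB) ** SENTENCE_LEN
print(total)  # 81 for a 3-character vocabulary and length-4 sentences
```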

During training you will see the model converge: the loss decreases and the accuracy increases.

[Training metrics plot]

Installation

Dependencies: pandas, torch, transformers, wandb. Install them with:

pip install pandas torch transformers wandb

Data Preparation

The first step is to generate the validation and test data; the training data will be generated on the fly during training. The validation and test data must be generated before training because every training sequence generated on the fly is checked against them, so that it never appears in the validation or test set (see the sketch below).

python get_data.py
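
The overlap check might look like the following sketch; the helper name, the CSV paths, and the column name are assumptions for illustration, not the repo's actual API:

```python
# A minimal sketch of the overlap check, not the repo's code: the helper
# name, the CSV paths and the "sentence" column are assumptions.
import random

import pandas as pd

VOCAB = "ABC"     # assumed to come from config.py
SENTENCE_LEN = 7  # assumed to come from config.py

def generate_sentence() -> str:
    return "".join(random.choices(VOCAB, k=SENTENCE_LEN))

# Load the held-out sets produced by get_data.py.
held_out = set(pd.read_csv("val.csv")["sentence"]) | set(
    pd.read_csv("test.csv")["sentence"]
)

def training_example() -> tuple[str, str]:
    """Draw an on-the-fly training pair that never overlaps val/test."""
    while True:
        s = generate_sentence()
        if s not in held_out:
            return s, "".join(sorted(s))
```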

Training

The training can be started with the following command:

python train.py

This will start the training with the default parameters, which you can change in the config.py and train.py files. The training is logged to wandb, and at the end of the training the model is saved in a checkpoints folder.
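
For orientation, the wandb logging and checkpoint saving could look roughly like this sketch; the model, the data, and all hyperparameter values are stand-ins, not the repo's actual training loop:

```python
# A minimal sketch of the wandb logging and checkpointing pattern, not the
# repo's training loop: the model, the data and all hyperparameters are
# stand-ins.
import os

import torch
import torch.nn as nn
import wandb

model = nn.Linear(8, 8)  # stand-in for the Transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
wandb.init(project="microTransformer", mode="offline")  # assumed project name

for epoch in range(10):  # the epoch count is an assumption
    batch = torch.randn(32, 8)  # stand-in batch
    loss = nn.functional.mse_loss(model(batch), batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    wandb.log({"loss": loss.item(), "epoch": epoch})

os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/epoch_10.pt")
```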

Sampling

Once the model is trained, you can test it on your own sentences! Whether the model sorted the sentence correctly is checked against Python's built-in sorted function (see the sketch at the end of this section). To sample a sentence, use the following command:

python sample.py --model_path <model_checkpoint> --sentence <your_sentence>

For example:

python sample.py --model_path checkpoints/epoch_10.pt --sentence "ABABBBB"

This will output:

Input: ABABBBB
Output: AABBBBB (correct)
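
The correctness check itself is a one-liner against Python's built-in sorted; a minimal sketch:

```python
# The check is repo-independent: compare the model's output against
# Python's built-in sorted().
def is_correct(sentence: str, model_output: str) -> bool:
    return model_output == "".join(sorted(sentence))

print(is_correct("ABABBBB", "AABBBBB"))  # True, matching the example above
```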

Credits