MLM (Masked Language Modeling) Pytorch

This repository is a concise, simple implementation of masked language modeling in PyTorch, letting you quickly set up unsupervised pretraining for your transformer on a corpus of sequence data.

Install

$ pip install mlm-pytorch

Usage

First `pip install x-transformers`, then run the following example to see what one iteration of the unsupervised training looks like.

import torch
from torch import nn
from torch.optim import Adam
from mlm_pytorch import MLM

# instantiate the language model

from x_transformers import TransformerWrapper, Encoder

transformer = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# plug the language model into the MLM trainer

trainer = MLM(
    transformer,
    mask_token_id = 2,          # the token id reserved for masking
    pad_token_id = 0,           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    replace_prob = 0.90,        # probability that a selected token is replaced with the mask token; the remaining ~10% are left unchanged but still included in the loss, as detailed in the paper
    mask_ignore_token_ids = []  # other tokens to exclude from masking, include the [cls] and [sep] here
).cuda()
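
# note: with the settings above, each non-excluded token has a 0.15 * 0.90 = 13.5%
# chance of being replaced by the mask token and a 0.15 * 0.10 = 1.5% chance of
# being left unchanged while still contributing to the loss; padding and any ids
# in mask_ignore_token_ids are never selected for masking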

# optimizer

opt = Adam(trainer.parameters(), lr=3e-4)

# one training step (do this for many steps in a for loop, getting new `data` each time)

data = torch.randint(0, 20000, (8, 1024)).cuda()

loss = trainer(data)
loss.backward()
opt.step()
opt.zero_grad()

# after much training, the model should have improved for downstream tasks

torch.save(transformer, './pretrained-model.pt')
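
In practice you would loop over many batches drawn from your own corpus rather than a single random tensor. Below is a minimal training-loop sketch; the random `corpus` tensor stands in for your tokenized data, and the epoch count, batch size, and checkpoint path are illustrative assumptions, not prescriptions from this library.

from torch.utils.data import DataLoader, TensorDataset

# stand-in corpus of token ids; in practice, load your own tokenized sequences
corpus = torch.randint(0, 20000, (10000, 1024))
loader = DataLoader(TensorDataset(corpus), batch_size = 8, shuffle = True)

for epoch in range(10):
    for (data,) in loader:
        loss = trainer(data.cuda())
        loss.backward()
        opt.step()
        opt.zero_grad()

    # checkpoint the underlying transformer after each epoch
    torch.save(transformer, f'./pretrained-model.{epoch}.pt')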

Do the above for many steps, and your model should improve; the saved transformer can then be reused for downstream tasks.
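
As a rough illustration of that downstream reuse, the sketch below attaches a classification head on top of the pretrained encoder and runs one fine-tuning step. The mean pooling, head, labels, and hyperparameters are hypothetical choices for this example, and it assumes a version of x-transformers whose `TransformerWrapper` accepts `return_embeddings = True` to return hidden states instead of logits.

import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

# hypothetical downstream setup: mean-pool the encoder output and classify
num_classes = 2
head = nn.Linear(512, num_classes).cuda()  # 512 matches the encoder dim above

finetune_opt = Adam([*transformer.parameters(), *head.parameters()], lr = 1e-5)

# stand-in labeled batch, purely for illustration
tokens = torch.randint(0, 20000, (8, 1024)).cuda()
labels = torch.randint(0, num_classes, (8,)).cuda()

embeds = transformer(tokens, return_embeddings = True)  # (batch, seq, dim) hidden states
logits = head(embeds.mean(dim = 1))                     # mean pool over the sequence

loss = F.cross_entropy(logits, labels)
loss.backward()
finetune_opt.step()
finetune_opt.zero_grad()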

Citation

@misc{devlin2018bert,
    title   = {BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
    author  = {Jacob Devlin and Ming-Wei Chang and Kenton Lee and Kristina Toutanova},
    year    = {2018},
    eprint  = {1810.04805},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}