Rethinking Masked Language Modeling for Chinese Spelling Correction

This is the official repo for the ACL 2023 paper Rethinking Masked Language Modeling for Chinese Spelling Correction and the AAAI 2024 paper Chinese Spelling Correction as Rephrasing Language Model.

Fine-tuning results on some of the benchmarks:

Model                 EC-LAW   EC-MED   EC-ODW   MCSC
BERT                    39.8     22.3     25.0   70.7
MDCSpell-Masked-FT      80.6     69.6     66.9   78.5
Baichuan2-Masked-FT     86.0     73.2     82.6   75.5
ReLM                    95.6     89.9     92.3   83.2

New

ReLM

The pre-trained ReLM model is released. It is a rephrasing language model trained on top of bert-base-chinese with 34 million monolingual sentences.

The main idea is illustrated in the figure below. We concatenate the input with a sequence of mask tokens of the same length as the input, and train the model to rephrase the entire sentence by infilling the masked slots, instead of performing character-to-character tagging. We also apply the masked-fine-tuning technique during training, which masks a proportion of characters in the source sentence; the source sentence is not masked at evaluation time.

relm-m0.3.bin

Different from BERT-MFT, ReLM is a pure language model, which optimizes the rephrasing language modeling objective instead of sequence tagging.

import torch
from autocsc import AutoCSCReLM

model = AutoCSCReLM.from_pretrained("bert-base-chinese",
                                    state_dict=torch.load("relm-m0.3.bin"),
                                    cache_dir="cache")
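For intuition, the input that ReLM consumes is the source sentence concatenated with an equal-length run of mask tokens. Below is a minimal sketch of that construction; the actual preprocessing lives in run.py and may differ in detail (truncation, padding, masked-fine-tuning).

# Minimal sketch of the ReLM input construction (illustrative; see run.py for the real logic).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

src = "发动机故障切纪盲目拆检"                    # source sentence containing an error
src_tokens = list(src)                            # CSC operates on Chinese characters
mask_tokens = [tokenizer.mask_token] * len(src_tokens)

# [CLS] source [SEP] [MASK] x len(source) [SEP]
tokens = [tokenizer.cls_token] + src_tokens + [tokenizer.sep_token] \
         + mask_tokens + [tokenizer.sep_token]
input_ids = tokenizer.convert_tokens_to_ids(tokens)

# The model is trained to infill the [MASK] slots with the corrected sentence,
# using the language modeling cross-entropy over those positions.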

Monolingual data

We share the training data used for LEMON. It contains 34 million monolingual sentences, from which we synthesize sentence pairs based on our confusion set in confus.

monolingual-wiki-news-l64

We split the data into 343 sub-files, each containing 100,000 sentences. The total size of the .zip file is 1.5 GB.
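For illustration, pairs like these could be synthesized roughly as follows. This is a toy sketch; the actual procedure and the format of the confusion set in confus may differ.

# Toy confusion-set-based error synthesis (a hedged sketch, not the repository's script).
import random

def synthesize_pair(sentence, confusion_set, error_rate=0.1):
    """Return a (noisy, clean) pair by swapping some characters with confusable ones."""
    noisy = []
    for ch in sentence:
        if ch in confusion_set and random.random() < error_rate:
            noisy.append(random.choice(confusion_set[ch]))
        else:
            noisy.append(ch)
    return "".join(noisy), sentence

# Toy confusion set: character -> phonetically or visually similar characters.
confusion_set = {"忌": ["纪", "记"], "盲": ["忙"]}
print(synthesize_pair("发动机故障切忌盲目拆检", confusion_set))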

Our code now supports multiple GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --multi_gpu run.py \
  --do_train \
  --do_eval \
  --fp16 \
  --mft

LEMON

LEMON (large-scale multi-domain dataset with natural spelling errors) is a novel benchmark released with our paper. All test sets are in lemon_v2.

Note: this dataset can only be used for academic research; it cannot be used for commercial purposes.

The other test sets we use in the paper are in sighan_ecspell.

The confusion sets are in confus.

Trained weights

In our paper, we train BERT for 30,000 steps with a learning rate of 5e-5 and a batch size of 8192. The backbone model is bert-base-chinese. We share our trained model weights to facilitate future research and welcome researchers to develop better models on top of ours.

BERT-finetune-MFT

BERT-finetune-MFT-CreAT-maskany

BERT-SoftMasked-MFT
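These checkpoints can presumably be loaded in the same way as the ReLM weights above, by passing the state dict to the matching class in autocsc.py. The sketch below uses AutoCSCSoftMasked for the Soft-Masked checkpoint; the local filename is hypothetical.

# Hedged sketch: loading a shared checkpoint into its AutoCSC class.
import torch
from autocsc import AutoCSCSoftMasked

model = AutoCSCSoftMasked.from_pretrained("bert-base-chinese",
                                          state_dict=torch.load("bert-softmasked-mft.bin"),  # hypothetical filename
                                          cache_dir="cache")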

AutoCSC

We implement several architectures from recent CSC papers in autocsc.py.

For instance (Soft-Masked BERT):

from autocsc import AutoCSCSoftMasked

# Load the model, similar to huggingface transformers.
model = AutoCSCSoftMasked.from_pretrained("bert-base-chinese",
                                          cache_dir="cache")

# Run a forward pass.
outputs = model(src_ids=src_ids,
                attention_mask=attention_mask,
                trg_ids=trg_ids)
loss = outputs["loss"]
prd_ids = outputs["predict_ids"].tolist()
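The snippet assumes src_ids, attention_mask, and trg_ids already exist. Below is a minimal sketch of one way to build them, assuming equal-length source and target character sequences as in standard CSC data (run.py uses its own data processors for this):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese", cache_dir="cache")

src = "发动机故障切纪盲目拆检"   # source with a spelling error
tgt = "发动机故障切忌盲目拆检"   # corrected target of the same length

# Character-level tokenization keeps source and target aligned position by position.
enc_src = tokenizer(list(src), is_split_into_words=True, return_tensors="pt")
enc_tgt = tokenizer(list(tgt), is_split_into_words=True, return_tensors="pt")

src_ids = enc_src["input_ids"]
attention_mask = enc_src["attention_mask"]
trg_ids = enc_tgt["input_ids"]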

Inference for ReLM

from autocsc import AutoCSCReLM
import torch
from transformers import AutoTokenizer
from run import *


load_state_path = '../csc_model/lemon/ReLM/relm-m0.3.bin'

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese',
                                          use_fast=True,
                                          add_prefix_space=True)

model = AutoCSCReLM.from_pretrained('bert-base-chinese',
                                    state_dict=torch.load(load_state_path),
                                    cache_dir="../cache")
max_seq_length = 256
src = ['发动机故障切纪盲目拆检']  # source sentence with a spelling error ("纪" should be "忌")
tgt = ['发动机故障切忌盲目拆检']  # reference correction

def decode(input_ids):
    return tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=True)

# Convert the (source, target) pair to ReLM features: the source concatenated with [MASK] slots.
processor = DataProcessorForRephrasing()
lines = [(list(src[i]), list(tgt[i])) for i in range(len(src))]
eval_examples = processor._create_examples(lines, 'test')
eval_features = processor.convert_examples_to_features(eval_examples, max_seq_length, tokenizer, False)
src_ids = torch.tensor([f.src_ids for f in eval_features], dtype=torch.long)
attention_mask = torch.tensor([f.attention_mask for f in eval_features], dtype=torch.long)
trg_ids = torch.tensor([f.trg_ids for f in eval_features], dtype=torch.long)

model.eval()  # disable dropout for deterministic inference
all_inputs, all_labels, all_predictions = [], [], []
with torch.no_grad():
    outputs = model(src_ids=src_ids,
                    attention_mask=attention_mask,
                    trg_ids=trg_ids)
    prd_ids = outputs["predict_ids"]
for s, t, p in zip(src_ids.tolist(), trg_ids.tolist(), prd_ids.tolist()):
    # Keep only the positions where the source holds [MASK], i.e. the rephrased (corrected) segment.
    _t = [tt for tt, st in zip(t, s) if st == tokenizer.mask_token_id]
    _p = [pt for pt, st in zip(p, s) if st == tokenizer.mask_token_id]

    all_inputs += [decode(s)]
    all_labels += [decode(_t)]
    all_predictions += [decode(_p)]

print(all_inputs)
print(all_labels)
print(all_predictions)

If you have new models or suggestions for improving our implementations, feel free to email me.

Running (set --mft for Masked-FT):

CUDA_VISIBLE_DEVICES=0 python run.py \
  --do_train \
  --do_eval \
  --train_on xxx.txt \
  --eval_on xx.txt \
  --output_dir mft \
  --max_train_steps 10000 \
  --fp16 \
  --model_type mdcspell \
  --mft

Directly testing on LEMON (including SIGHAN):

CUDA_VISIBLE_DEVICES=0 python run.py \
  --test_on_lemon ../data/lemon \
  --output_dir relm \
  --model_type relm \
  --load_state_dict relm-m0.3.bin
