This is the official repo for the ACL 2023 paper Rethinking Masked Language Modeling for Chinese Spelling Correction and the AAAI 2024 paper Chinese Spelling Correction as Rephrasing Language Model.
Fine-tuning results on some of the benchmarks:
| Model | EC-LAW | EC-MED | EC-ODW | MCSC |
|---|---|---|---|---|
| BERT | 39.8 | 22.3 | 25.0 | 70.7 |
| MDCSpell-Masked-FT | 80.6 | 69.6 | 66.9 | 78.5 |
| Baichuan2-Masked-FT | 86.0 | 73.2 | 82.6 | 75.5 |
| ReLM | 95.6 | 89.9 | 92.3 | 83.2 |
**New**
ReLM
The pre-trained ReLM model is released. It is a rephrasing language model built on bert-base-chinese and trained on 34 million monolingual sentences.
The main idea is illustrated in the figure below. We concatenate the input with a sequence of mask tokens of the same length as the input, and train the model to rephrase the entire sentence by infilling the additional slots, instead of performing character-to-character tagging. We also apply the masked-fine-tuning technique during training, which masks a proportion of characters in the source sentence; the source sentence is not masked during evaluation.
Unlike BERT-MFT, ReLM is a pure language model: it optimizes the rephrasing language modeling objective instead of a sequence tagging objective.
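As a rough sketch (the actual preprocessing lives in DataProcessorForRephrasing in run.py), the rephrasing input can be thought of as the source sentence followed by an equal-length run of mask slots; the helper name, the exact token layout, and the 0.3 masking ratio below are illustrative assumptions only.
import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

def build_rephrasing_input(src, mft_ratio=0.0):
    """Hypothetical sketch of a ReLM-style input; not the repo's exact format."""
    chars = list(src)
    # Masked-fine-tuning: randomly mask a proportion of source characters
    # during training; use mft_ratio=0.0 at evaluation time.
    masked_src = [tokenizer.mask_token if random.random() < mft_ratio else c
                  for c in chars]
    # Concatenate the (possibly masked) source with an equal-length run of
    # [MASK] slots that the model learns to infill with the rephrased sentence.
    tokens = ([tokenizer.cls_token] + masked_src + [tokenizer.sep_token]
              + [tokenizer.mask_token] * len(chars) + [tokenizer.sep_token])
    return tokenizer.convert_tokens_to_ids(tokens)

input_ids = build_rephrasing_input("发动机故障切纪盲目拆检", mft_ratio=0.3)
To load the released ReLM checkpoint: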
import torch
from autocsc import AutoCSCReLM

model = AutoCSCReLM.from_pretrained("bert-base-chinese",
                                    state_dict=torch.load("relm-m0.3.bin"),
                                    cache_dir="cache")

Monolingual data
We share the training data we used for LEMON. It contains 34 million monolingual sentences, from which we synthesize sentence pairs using the confusion sets in confus.
We split the data into 343 sub-files with 100,000 sentences each. The total size of the .zip file is 1.5 GB.
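A minimal sketch of how such pairs could be synthesized is shown below; the confusion-set file name, its assumed format, and the 10% corruption rate are illustrative assumptions, not the exact recipe used for the released data.
import random

def load_confusions(path):
    # Assumed format: each line holds a character followed by its confusable
    # characters, separated by whitespace (hypothetical file layout).
    confusions = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                confusions[parts[0]] = list("".join(parts[1:]))
    return confusions

def corrupt(sentence, confusions, rate=0.1):
    # Replace a small proportion of characters with confusable ones.
    chars = list(sentence)
    for i, c in enumerate(chars):
        if c in confusions and random.random() < rate:
            chars[i] = random.choice(confusions[c])
    return "".join(chars)

confusions = load_confusions("confus/confusion.txt")  # hypothetical file name
clean = "发动机故障切忌盲目拆检"
noisy = corrupt(clean, confusions)  # (noisy, clean) forms one training pair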
Our code supports multiple GPUs now:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --multi_gpu run.py \
--do_train \
--do_eval \
--fp16 \
--mft

LEMON (large-scale multi-domain dataset with natural spelling errors) is a novel benchmark released with our paper. All test sets are in lemon_v2.
Note: this dataset may only be used for academic research; it cannot be used for commercial purposes.
The other test sets we use in the paper are in sighan_ecspell.
The confusion sets are in confus.
Trained weights
In our paper, we train BERT for 30,000 steps with a learning rate of 5e-5 and a batch size of 8192. The backbone model is bert-base-chinese. We share our trained model weights to facilitate future research and welcome researchers to develop better ones based on our models. A rough sketch of this training configuration is shown below.
BERT-finetune-MFT-CreAT-maskany
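For reference, a minimal sketch of such a training setup in PyTorch follows; the optimizer, the warmup length, and the use of gradient accumulation to reach the 8192-sentence batch are assumptions rather than the exact recipe in run.py.
import torch
from transformers import BertForMaskedLM, get_linear_schedule_with_warmup

# Stand-in backbone for illustration; the repo's AutoCSC* wrappers build on bert-base-chinese.
model = BertForMaskedLM.from_pretrained("bert-base-chinese")

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  # learning rate reported above
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,        # assumed warmup length (not reported)
    num_training_steps=30000)     # 30,000 training steps as reported
# A batch of 8192 sentences is typically reached by combining per-GPU batches
# with gradient accumulation across multiple GPUs.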
We implement some model architectures from recent CSC papers in autocsc.py.
For instance (Soft-Masked BERT):
from autocsc import AutoCSCSoftMasked

# Load the model, similar to huggingface transformers.
model = AutoCSCSoftMasked.from_pretrained("bert-base-chinese",
                                          cache_dir="cache")

# Go forward step.
outputs = model(src_ids=src_ids,
                attention_mask=attention_mask,
                trg_ids=trg_ids)
loss = outputs["loss"]
prd_ids = outputs["predict_ids"].tolist()

Inference for ReLM
import torch
from transformers import AutoTokenizer
from autocsc import AutoCSCReLM
from run import DataProcessorForRephrasing

load_state_path = '../csc_model/lemon/ReLM/relm-m0.3.bin'

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese',
                                          use_fast=True,
                                          add_prefix_space=True)
model = AutoCSCReLM.from_pretrained('bert-base-chinese',
                                    state_dict=torch.load(load_state_path),
                                    cache_dir="../cache")
model.eval()

max_seq_length = 256
src = ['发动机故障切纪盲目拆检']
tgt = ['发动机故障切忌盲目拆检']

def decode(input_ids):
    return tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=True)

processor = DataProcessorForRephrasing()
lines = [(list(src[i]), list(tgt[i])) for i in range(len(src))]
eval_examples = processor._create_examples(lines, 'test')
eval_features = processor.convert_examples_to_features(eval_examples, max_seq_length, tokenizer, False)

src_ids = torch.tensor([f.src_ids for f in eval_features], dtype=torch.long)
attention_mask = torch.tensor([f.attention_mask for f in eval_features], dtype=torch.long)
trg_ids = torch.tensor([f.trg_ids for f in eval_features], dtype=torch.long)

all_inputs, all_labels, all_predictions = [], [], []
with torch.no_grad():
    outputs = model(src_ids=src_ids,
                    attention_mask=attention_mask,
                    trg_ids=trg_ids)
    prd_ids = outputs["predict_ids"]

for s, t, p in zip(src_ids.tolist(), trg_ids.tolist(), prd_ids.tolist()):
    # Keep only the rephrasing slots, i.e. positions that are [MASK] in the source.
    _t = [tt for tt, st in zip(t, s) if st == tokenizer.mask_token_id]
    _p = [pt for pt, st in zip(p, s) if st == tokenizer.mask_token_id]
    all_inputs += [decode(s)]
    all_labels += [decode(_t)]
    all_predictions += [decode(_p)]

print(all_inputs)
print(all_labels)
print(all_predictions)

If you have new models or suggestions for improving our implementations, feel free to email me.
Running (set --mft for Masked-FT):
CUDA_VISIBLE_DEVICES=0 python run.py \
--do_train \
--do_eval \
--train_on xxx.txt \
--eval_on xx.txt \
--output_dir mft \
--max_train_steps 10000 \
--fp16 \
--model_type mdcspell \
--mft

Directly testing on LEMON (including SIGHAN):
CUDA_VISIBLE_DEVICES=0 python run.py \
--test_on_lemon ../data/lemon \
--output_dir relm \
--model_type relm \
--load_state_dict relm-m0.3.bin