Quick Training
```python
import os
import json
import shutil

from modelscope.pipelines import pipeline
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.compute_wer import compute_wer


def modelscope_finetune(params):
    if not os.path.exists(params["model_dir"]):
        os.makedirs(params["model_dir"], exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params["dataset_name"], namespace='speech_asr')
    kwargs = dict(
        model=params["modelscope_model_name"],
        data_dir=ds_dict,
        work_dir=params["model_dir"],
        max_epoch=1)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
    # copy the files needed for decoding from the pretrained model cache
    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
    required_files = ["am.mvn", "decoding.yaml", "configuration.json"]
    for file_name in required_files:
        shutil.copy(os.path.join(pretrained_model_path, file_name),
                    os.path.join(params["model_dir"], file_name))


def modelscope_infer(params):
    # prepare for decoding: point the configuration at the finetuned checkpoint
    with open(os.path.join(params["model_dir"], "configuration.json")) as f:
        config_dict = json.load(f)
        config_dict["model"]["am_model_name"] = params["decoding_model_name"]
    with open(os.path.join(params["model_dir"], "configuration.json"), "w") as f:
        json.dump(config_dict, f, indent=4, separators=(',', ': '))
    decoding_path = os.path.join(params["model_dir"], "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)

    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=params["model_dir"],
        output_dir=decoding_path,
        batch_size=64
    )
    audio_in = os.path.join(params["test_data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)

    # compute CER if ground-truth text is available
    text_in = os.path.join(params["test_data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
        os.system("tail -n 3 {}".format(os.path.join(decoding_path, "text.cer")))


if __name__ == '__main__':
    finetune_params = {}
    finetune_params["modelscope_model_name"] = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    finetune_params["dataset_name"] = "speech_asr_aishell1_subset"
    finetune_params["model_dir"] = "./checkpoint"
    modelscope_finetune(finetune_params)

    infer_params = {}
    infer_params["model_dir"] = "./checkpoint"
    infer_params["decoding_model_name"] = "1epoch.pb"
    infer_params["test_data_dir"] = "./checkpoint/data/validation"
    modelscope_infer(infer_params)
```
The script above (finetune.py) runs the full workflow: download the ModelScope model -> download the ModelScope dataset -> train the model -> test the model and compute the CER:

```
python finetune.py
```
The finetuning step (modelscope_finetune) takes:
- modelscope_model_name: name of the ModelScope model to finetune
- dataset_name: name of the ModelScope dataset
- model_dir: directory where the finetuned model is saved
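
Only these three keys are read from `finetune_params`; the number of training epochs is fixed inside `modelscope_finetune` through the `max_epoch` value passed to `build_trainer`. Below is a minimal sketch of reusing the script's own entry point with a different output directory; the `./checkpoint_run2` path is a hypothetical example, not a default:

```python
# Hypothetical variation of the __main__ block in finetune.py: same model and
# dataset as the defaults, but a separate work directory so an earlier
# ./checkpoint run is not overwritten.
finetune_params = {
    "modelscope_model_name": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "dataset_name": "speech_asr_aishell1_subset",
    "model_dir": "./checkpoint_run2",   # hypothetical output directory
}
modelscope_finetune(finetune_params)
```

To train for more than one epoch, raise `max_epoch` in the `kwargs` dict inside `modelscope_finetune`, and update `decoding_model_name` so it matches the checkpoint file that ends up in `model_dir` (with `max_epoch=1` this is `1epoch.pb`, as shown in the directory tree below).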
```
tree ./checkpoint/
./checkpoint/
├── 1epoch.pb
├── tensorboard
│   ├── train
│   └── valid
```
- 1epoch.pb: model file saved after training for one epoch
- tensorboard: directory holding the TensorBoard logs. Launch TensorBoard with `tensorboard --logdir checkpoint/tensorboard/train/`, then open `<training-server-ip>:6006` in a browser to view it.
The decoding step (modelscope_infer) takes:
- model_dir: directory of the model used for decoding
- test_data_dir: directory containing the test data
- decoding_model_name: name of the model file used for decoding
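
Decoding does not require rerunning the finetuning step: `modelscope_infer` from the script above can be called on its own against an existing checkpoint directory. A minimal sketch, assuming `./checkpoint` already contains `1epoch.pb` plus the copied `am.mvn`, `decoding.yaml` and `configuration.json`, and assuming a hypothetical `./my_test_set` directory laid out like `./checkpoint/data/validation` (a `wav.scp` file, plus an optional `text` file for CER scoring):

```python
# Hypothetical standalone decoding run that reuses an already finetuned checkpoint.
infer_params = {
    "model_dir": "./checkpoint",          # directory produced by modelscope_finetune
    "decoding_model_name": "1epoch.pb",   # checkpoint file to decode with
    "test_data_dir": "./my_test_set",     # hypothetical test set with wav.scp (and optionally text)
}
modelscope_infer(infer_params)
```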
```
tree ./checkpoint/decode_results/
./checkpoint/decode_results/
├── text.cer
└── 1best_recog
    ├── rtf
    ├── text
    └── score
```
- text.cer: CER statistics file

```
BAC009S0724W0495.wav(nwords=13,cor=13,ins=0,del=0,sub=0) corr:100.00%,cer:0.00%
ref: 筹备了一系列新展并同时亮相
hyp: 筹备了一系列新展并同时亮相
%WER 3.51 [ 177 / 5037, 4 ins, 0 del, 173 sub ]
%SER 28.81 [ 102 / 354 ]
```
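
The summary line ties together as follows: the bracketed error count is insertions + deletions + substitutions, and the percentage is that count over the number of reference characters (compute_wer labels the line %WER, but for character-level Chinese scoring it is the CER). A quick arithmetic check of the figures above, shown in Python purely for illustration:

```python
# Reproduce "%WER 3.51 [ 177 / 5037, 4 ins, 0 del, 173 sub ]" from its parts.
ins, dels, subs = 4, 0, 173
ref_chars = 5037                    # total reference characters
errors = ins + dels + subs          # 177
cer = 100.0 * errors / ref_chars    # 3.5139... -> reported as 3.51
print(errors, round(cer, 2))        # 177 3.51
```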
- rtf: decoding time of each utterance

```
BAC009S0724W0495.wav decoding, feature length: 3075, forward_time: 0.4731, rtf: 0.0026
rtf_avf decoding, feature length total: 29815.0, forward_time total: 4.7479, rtf avg: 0.0027
```
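
The rtf column is the forward time divided by the audio duration covered by the features. The reported numbers are consistent with each feature frame covering 60 ms of audio; that frame duration is an inference from the values above (typical of a low-frame-rate front end), not something stated in the log, so treat this reconstruction as a sketch:

```python
# Hypothetical reconstruction of the rtf values above, assuming 60 ms per feature frame.
frame_dur_s = 0.06                        # assumed duration of one feature frame
print(0.4731 / (3075 * frame_dur_s))      # ~0.00256 -> reported rtf 0.0026
print(4.7479 / (29815.0 * frame_dur_s))   # ~0.00265 -> reported rtf avg 0.0027
```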
- score: decoding score of each utterance

```
BAC009S0724W0495.wav tensor(-1.2249, device='cuda:0')
```
- text: decoding results

```
BAC009S0724W0495.wav 筹备了一系列新展并同时亮相
```
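
The `text` output contains one utterance per line: an utterance ID followed by its transcript, separated by whitespace. A minimal sketch for loading the hypotheses back into Python (the path matches the decode_results tree above; `load_hyps` is just an illustrative helper name):

```python
# Read "utt-id<whitespace>transcript" lines into a dict keyed by utterance ID.
def load_hyps(path="./checkpoint/decode_results/1best_recog/text"):
    hyps = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                utt_id, transcript = parts
                hyps[utt_id] = transcript
    return hyps


hyps = load_hyps()
print(hyps.get("BAC009S0724W0495.wav"))
```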