Skip to content

Commit

Permalink
更新配置文件系统 (#129)
Browse files Browse the repository at this point in the history
* Update config

* Update config
Add autoLoad

* Update config

* Update config

* Update Bert-VITS2 fp16

* Fix dimensional_emotion_model
Update empty folder
Move the text parameter to the end of the url.

* Update default_parameter

* Fix download issue and add new proxy

* Update device info

* Update version process

* Update configuration page

* Update api key setting

* Fix model load

* Update map_location

* Update docs

* Update docs
  • Loading branch information
Artrajz authored Jan 20, 2024
1 parent 48ba9ac commit e1daa46
Show file tree
Hide file tree
Showing 113 changed files with 1,339 additions and 1,117 deletions.
11 changes: 8 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
/logs/
/cache/
/upload/
/vits/text/chinese_dialect_lexicons/
/vits/bert/prosody_model.pt
config.yml
**/pytorch_model.bin
**/spm.model
/**/*.pt
phrases_dict.txt
/config.yml
/config.yaml
/data/emotional/dimensional_emotion_model/model.onnx
/data/hubert_soft/hubert-soft-0d54a1f4.pt
/data/emotional/dimensional_emotion_npy/
/data/bert/vits_chinese_bert/prosody_model.pt
/data/emotional/dimensional_emotion_npy/
132 changes: 68 additions & 64 deletions README.md

Large diffs are not rendered by default.

128 changes: 66 additions & 62 deletions README_zh.md

Large diffs are not rendered by default.

41 changes: 21 additions & 20 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,49 @@

from utils.data_utils import clean_folder
from utils.phrases_dict import phrases_dict_init
from tts_app import frontend, voice_api, auth, admin
from utils.config_manager import global_config
from tts_app.frontend.views import frontend
from tts_app.voice_api.views import voice_api
from tts_app.auth.views import auth
from tts_app.admin.views import admin

from contants import config

app = Flask(__name__, template_folder=os.path.join(os.path.dirname(__file__), 'tts_app', 'templates'),
static_folder=os.path.join(os.path.dirname(__file__), 'tts_app', 'static'))

app.config.from_pyfile("config.py")
app.config.update(global_config)
# app.config.update(config)

phrases_dict_init()

csrf = CSRFProtect(app)
# 禁用tts api请求的CSRF防护
csrf.exempt(voice_api)

if app.config.get("IS_ADMIN_ENABLED", False):
if config.system.is_admin_enabled:
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'auth.login'


@login_manager.user_loader
def load_user(user_id):
users = app.config["users"]["admin"]
for user in users.values():
if user.get_id() == user_id:
return user
admin = config.admin
if admin.get_id() == user_id:
return admin
return None

# Initialize scheduler
scheduler = APScheduler()
scheduler.init_app(app)
if app.config.get("CLEAN_INTERVAL_SECONDS", 3600) > 0:
if config.system.clean_interval_seconds > 0:
scheduler.start()

app.register_blueprint(frontend, url_prefix='/')
app.register_blueprint(voice_api, url_prefix='/voice')
if app.config.get("IS_ADMIN_ENABLED", False):
app.register_blueprint(auth, url_prefix=app.config.get("ADMIN_ROUTE", "/admin"))
app.register_blueprint(admin, url_prefix=app.config.get("ADMIN_ROUTE", "/admin"))
if config.system.is_admin_enabled:
app.register_blueprint(auth, url_prefix=config.system.admin_route)
app.register_blueprint(admin, url_prefix=config.system.admin_route)


def create_folders(paths):
Expand All @@ -55,19 +58,17 @@ def create_folders(paths):
os.makedirs(path, exist_ok=True)


create_folders([app.config["UPLOAD_FOLDER"],
app.config["CACHE_PATH"],
os.path.join(app.config["ABS_PATH"], "Model")
])
create_folders([os.path.join(config.abs_path, config.system.upload_folder),
os.path.join(config.abs_path, config.system.cache_path), ])


# regular cleaning
@scheduler.task('interval', id='clean_task', seconds=app.config.get("CLEAN_INTERVAL_SECONDS", 3600),
@scheduler.task('interval', id='clean_task', seconds=config.system.clean_interval_seconds,
misfire_grace_time=900)
def clean_task():
clean_folder(app.config["UPLOAD_FOLDER"])
clean_folder(app.config["CACHE_PATH"])
clean_folder(os.path.join(config.abs_path, config.system.upload_folder))
clean_folder(os.path.join(config.abs_path, config.system.cache_path))


if __name__ == '__main__':
app.run(host='0.0.0.0', port=app.config.get("PORT", 23456), debug=app.config.get("DEBUG", False))
app.run(host=config.http_service.host, port=config.http_service.port, debug=config.http_service.debug)
10 changes: 7 additions & 3 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
from contants import config
from utils import get_hparams_from_file
from utils.sentence import split_by_language

Expand Down Expand Up @@ -92,22 +93,25 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
self.num_tones = num_tones
if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})

elif self.version in ["2.3", "2.3.0"]:
self.version = "2.3"
self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
self.num_tones = num_tones
self.text_extra_str_map.update({"en": "_v230"})
if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"})
elif self.version.lower().replace("-", "_") in ["extra", "zh_clap"]:

elif self.version is not None and self.version.lower().replace("-", "_") in ["extra", "zh_clap"]:
self.version = "extra"
self.hps_ms.model.emotion_embedding = 2
self.hps_ms.model.n_layers_trans_flow = 6
self.lang = ["zh"]
self.num_tones = num_tones
self.zh_bert_extra = True
self.bert_model_names.update({"zh": "Erlangshen-MegatronBert-1.3B-Chinese"})
self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
self.bert_extra_str_map.update({"zh": "_extra"})

else:
logging.debug("Version information not found. Loaded as the newest version: v2.3.")
self.version = "2.3"
Expand Down Expand Up @@ -214,7 +218,7 @@ def _get_clap(self, reference_audio, text_prompt):
emo = get_clap_audio_feature(reference_audio, self.model_handler.clap_model,
self.model_handler.clap_processor, self.device)
else:
if text_prompt is None: text_prompt = "Happy"
if text_prompt is None: text_prompt = config.bert_vits2_config.text_prompt
emo = get_clap_text_feature(text_prompt, self.model_handler.clap_model,
self.model_handler.clap_processor, self.device)
emo = torch.squeeze(emo, dim=1).unsqueeze(0)
Expand Down
4 changes: 2 additions & 2 deletions bert_vits2/get_emo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Wav2Vec2PreTrainedModel,
)

from utils.config_manager import global_config
from contants import config


class RegressionHead(nn.Module):
Expand Down Expand Up @@ -81,7 +81,7 @@ def process_func(

def get_emo(audio, emotion_model, processor):
wav, sr = librosa.load(audio, 16000)
device = global_config["DEVICE"]
device = config.system.device
return process_func(
np.expand_dims(wav, 0).astype(np.float),
sr,
Expand Down
58 changes: 36 additions & 22 deletions bert_vits2/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel

import config
from contants import config
from utils.download import download_file
from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert
from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert
Expand All @@ -17,7 +17,7 @@


class ModelHandler:
def __init__(self, device):
def __init__(self, device=config.system.device):
self.DOWNLOAD_PATHS = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": [
"https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, device):
"https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
"https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
],
"Erlangshen-MegatronBert-1.3B-Chinese": [
"Erlangshen_MegatronBert_1.3B_Chinese": [
"https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
"https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
]
Expand All @@ -71,22 +71,27 @@ def __init__(self, device):
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
"CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
"Erlangshen-MegatronBert-1.3B-Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
"Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
}
self.model_path = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": os.path.join(config.ABS_PATH,
"bert_vits2/bert/chinese-roberta-wwm-ext-large"),
"BERT_BASE_JAPANESE_V3": os.path.join(config.ABS_PATH, "bert_vits2/bert/bert-base-japanese-v3"),
"BERT_LARGE_JAPANESE_V2": os.path.join(config.ABS_PATH, "bert_vits2/bert/bert-large-japanese-v2"),
"DEBERTA_V2_LARGE_JAPANESE": os.path.join(config.ABS_PATH, "bert_vits2/bert/deberta-v2-large-japanese"),
"DEBERTA_V3_LARGE": os.path.join(config.ABS_PATH, "bert_vits2/bert/deberta-v3-large"),
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": os.path.join(config.ABS_PATH,
"bert_vits2/bert/deberta-v2-large-japanese-char-wwm"),
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": os.path.join(config.ABS_PATH,
"bert_vits2/emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"),
"CLAP_HTSAT_FUSED": os.path.join(config.ABS_PATH, "bert_vits2/emotional/clap-htsat-fused"),
"Erlangshen-MegatronBert-1.3B-Chinese": os.path.join(config.ABS_PATH,
"bert_vits2/bert/Erlangshen-MegatronBert-1.3B-Chinese"),
"CHINESE_ROBERTA_WWM_EXT_LARGE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.chinese_roberta_wwm_ext_large),
"BERT_BASE_JAPANESE_V3": os.path.join(config.abs_path, config.system.data_path,
config.model_config.bert_base_japanese_v3),
"BERT_LARGE_JAPANESE_V2": os.path.join(config.abs_path, config.system.data_path,
config.model_config.bert_large_japanese_v2),
"DEBERTA_V2_LARGE_JAPANESE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v2_large_japanese),
"DEBERTA_V3_LARGE": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v3_large),
"DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": os.path.join(config.abs_path, config.system.data_path,
config.model_config.deberta_v2_large_japanese_char_wwm),
"WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": os.path.join(config.abs_path, config.system.data_path,
config.model_config.wav2vec2_large_robust_12_ft_emotion_msp_dim),
"CLAP_HTSAT_FUSED": os.path.join(config.abs_path, config.system.data_path,
config.model_config.clap_htsat_fused),
"Erlangshen_MegatronBert_1.3B_Chinese": os.path.join(config.abs_path, config.system.data_path,
config.model_config.erlangshen_MegatronBert_1_3B_Chinese),
}

self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert, "ja_v111": ja_bert_v111,
Expand All @@ -96,6 +101,13 @@ def __init__(self, device):
self.emotion = None
self.clap = None
self.device = device
if config.bert_vits2_config.torch_data_type != "":
if config.bert_vits2_config.torch_data_type.lower() in ["float16","fp16"]:
self.torch_dtype = torch.float16
elif config.bert_vits2_config.torch_data_type.lower() in ["int8"]:
self.torch_dtype = torch.int8
else:
self.torch_dtype = None

@property
def emotion_model(self):
Expand Down Expand Up @@ -133,12 +145,14 @@ def load_bert(self, bert_model_name, max_retries=3):
model_path = self.model_path[bert_model_name]
logging.info(f"Loading BERT model: {model_path}")
try:
if bert_model_name == "bert_model_name":
tokenizer = BertTokenizer.from_pretrained(model_path)
model = MegatronBertModel.from_pretrained(model_path).to(self.device)
if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
tokenizer = BertTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
model = MegatronBertModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
self.device)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path).to(self.device)
tokenizer = AutoTokenizer.from_pretrained(model_path, torch_dtype=self.torch_dtype)
model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(
self.device)
self.bert_models[bert_model_name] = (tokenizer, model, 1) # 初始化引用计数为1
logging.info(f"Success loading: {model_path}")
break
Expand Down
9 changes: 5 additions & 4 deletions bert_vits2/text/chinese_bert.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import torch

from utils.config_manager import global_config
from contants import config


def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7, **kwargs):
def get_bert_feature(text, word2ph, tokenizer, model, device=config.system.device, style_text=None, style_weight=0.7,
**kwargs):
with torch.no_grad():
inputs = tokenizer(text, return_tensors='pt')
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = torch.cat(res['hidden_states'][-3:-2], -1)[0].cpu()
res = torch.cat(res['hidden_states'][-3:-2], -1)[0].float().cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = model(**style_inputs, output_hidden_states=True)
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
style_res_mean = style_res.mean(0)

assert len(word2ph) == len(text) + 2
Expand Down
17 changes: 9 additions & 8 deletions bert_vits2/text/chinese_bert_extra.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import torch

from utils.config_manager import global_config
from contants import config


def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7, **kwargs):
def get_bert_feature(text, word2ph, tokenizer, model, device=config.system.device, style_text=None, style_weight=0.7,
**kwargs):
with torch.no_grad():
inputs = tokenizer(text, return_tensors='pt')
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = torch.nn.functional.normalize(torch.cat(res["hidden_states"][-3:-2], -1)[0], dim=0).cpu()
res = torch.nn.functional.normalize(torch.cat(res["hidden_states"][-3:-2], -1)[0], dim=0).float().cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = model(**style_inputs, output_hidden_states=True)
style_res = torch.nn.functional.normalize(
torch.cat(style_res["hidden_states"][-3:-2], -1)[0], dim=0
).cpu()
).float().cpu()
style_res_mean = style_res.mean(0)
assert len(word2ph) == len(text) + 2
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
if style_text:
repeat_feature = (
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
)
else:
repeat_feature = res[i].repeat(word2phone[i], 1)
Expand All @@ -39,7 +40,7 @@ def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVIC

if __name__ == '__main__':

word_level_feature = torch.rand(38, 2048) # 12个词,每个词2048维特征
word_level_feature = torch.rand(38, 2048) # 12个词,每个词2048维特征
word2phone = [1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2,
2, 2, 2, 1]

Expand All @@ -56,4 +57,4 @@ def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVIC
phone_level_feature.append(repeat_feature)

phone_level_feature = torch.cat(phone_level_feature, dim=0)
print(phone_level_feature.shape) # torch.Size([36, 2048])
print(phone_level_feature.shape) # torch.Size([36, 2048])
10 changes: 5 additions & 5 deletions bert_vits2/text/english_bert_mock.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import torch

from utils.config_manager import global_config
from contants import config



def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7, **kwargs):
def get_bert_feature(text, word2ph, tokenizer, model, device=config.system.device, style_text=None, style_weight=0.7,
**kwargs):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
style_inputs[i] = style_inputs[i].to(device)
style_res = model(**style_inputs, output_hidden_states=True)
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
style_res_mean = style_res.mean(0)
assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
word2phone = word2ph
Expand Down
Loading

0 comments on commit e1daa46

Please sign in to comment.