From 89ade6e27bca77910a0274e2c771329684bdc100 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Wed, 3 Jan 2024 00:16:21 +0800 Subject: [PATCH] Update Bert-VITS2 v2.3 (#121) * Add download link * Update Bert-VITS2 v2.3 --- README.md | 26 +- README_zh.md | 49 +- TTSManager.py | 3 +- bert_vits2/bert_vits2.py | 61 +- bert_vits2/model_handler.py | 56 +- bert_vits2/models_v230.py | 985 ++++++++++++++++++++++ bert_vits2/text/chinese_bert.py | 17 +- bert_vits2/text/cleaner.py | 5 +- bert_vits2/text/english_bert_mock_v200.py | 20 +- bert_vits2/text/english_v230.py | 493 +++++++++++ bert_vits2/text/japanese_bert.py | 19 +- config.py | 1 + tts_app/static/js/index.js | 18 + tts_app/templates/pages/index.html | 17 +- tts_app/voice_api/views.py | 9 +- utils/config_manager.py | 9 +- 16 files changed, 1706 insertions(+), 82 deletions(-) create mode 100644 bert_vits2/models_v230.py create mode 100644 bert_vits2/text/english_v230.py diff --git a/README.md b/README.md index 9ac5406..d817b04 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,20 @@ pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/w ## Linux The installation process is similar, but I don't have the environment to test it. +# WebUI + +## Inference Frontend + + http://127.0.0.1:23456 + +*Port is modifiable under the default setting of port 23456. + +## Admin Backend + +The default address is http://127.0.0.1:23456/admin. + +The initial username and password can be found at the bottom of the config.yml file after the first startup. + # Function Options Explanation ## Disable the Admin Backend @@ -262,11 +276,6 @@ To ensure compatibility with the Bert-VITS2 model, modify the config.json file b ... ``` -# Admin Backend -The default address is http://127.0.0.1:23456/admin. - -The initial username and password can be found at the bottom of the config.yml file after the first startup. - # API ## GET @@ -372,8 +381,11 @@ After enabling it, you need to add the `api_key` parameter in GET requests and a | SDP noise | noisew | false | From `config.yml` | float | Stochastic Duration Predictor noise, controlling the length of phoneme pronunciation. | | Segment Size | segment_size | false | From `config.yml` | int | Divide the text into paragraphs based on punctuation marks, and combine them into one paragraph when the length exceeds segment_size. If segment_size<=0, the text will not be divided into paragraphs. | | SDP/DP mix ratio | sdp_ratio | false | From `config.yml` | int | The theoretical proportion of SDP during synthesis, the higher the ratio, the larger the variance in synthesized voice tone. 
| -| Emotion | emotion | false | None | | Available for Bert-VITS2 v2.1, ranging from 0 to 9 | -| Reference Audio | reference_audio | false | None | | Available for Bert-VITS2 v2.1 | +| Emotion | emotion | false | None | int | Available for Bert-VITS2 v2.1, ranging from 0 to 9 | +| Emotion reference Audio | reference_audio | false | None | | Bert-VITS2 v2.1 uses reference audio to control the synthesized audio's emotion | +|Text Prompt|text_prompt|false|None|str|Bert-VITS2 v2.2 text prompt used for emotion control| +|Style Text|style_text|false|None|str|Bert-VITS2 v2.3 text prompt used for emotion control| +|Style Text Weight|style_weight|false|From `config.yml`|float|Bert-VITS2 v2.3 text prompt weight used for prompt weighting| ## SSML (Speech Synthesis Markup Language) diff --git a/README_zh.md b/README_zh.md index 53ba563..7b66ebd 100644 --- a/README_zh.md +++ b/README_zh.md @@ -204,6 +204,20 @@ pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/w 安装过程类似,可以查阅网上的安装资料。也可以直接使用docker部署脚本中的gpu版本。 +# WebUI + +## 推理前端 + +http://127.0.0.1:23456 + +*在默认端口为23456的情况下,端口可修改 + +## 管理员后台 + +默认为http://127.0.0.1:23456/admin + +初始账号密码在初次启动后,在config.yml最下方可找到。 + # 功能选项说明 ## 关闭管理员后台 @@ -268,12 +282,6 @@ pip install pyopenjtalk -i https://pypi.artrajz.cn/simple ... ``` -# 管理员后台 - -默认为http://127.0.0.1:23456/admin - -初始账号密码在初次启动后,在config.yml最下方可找到。 - # API ## GET @@ -371,19 +379,22 @@ pip install pyopenjtalk -i https://pypi.artrajz.cn/simple ## Bert-VITS2语音合成 -| Name | Parameter | Is must | Default | Type | Instruction | -| ------------- | --------------- | ------- | -------------------- | ----- | ------------------------------------------------------------ | -| 合成文本 | text | true | | str | 需要合成语音的文本。 | -| 角色id | id | false | 从`config.yml`中获取 | int | 即说话人id。 | -| 音频格式 | format | false | 从`config.yml`中获取 | str | 支持wav,ogg,silk,mp3,flac | -| 文本语言 | lang | false | 从`config.yml`中获取 | str | auto为自动识别语言模式,也是默认模式,但目前只支持识别整段文本的语言,无法细分到每个句子。其余可选语言zh和ja。 | -| 语音长度/语速 | length | false | 从`config.yml`中获取 | float | 调节语音长度,相当于调节语速,该数值越大语速越慢。 | -| 噪声 | noise | false | 从`config.yml`中获取 | float | 样本噪声,控制合成的随机性。 | -| sdp噪声 | noisew | false | 从`config.yml`中获取 | float | 随机时长预测器噪声,控制音素发音长度。 | -| 分段阈值 | segment_size | false | 从`config.yml`中获取 | int | 按标点符号分段,加起来大于segment_size时为一段文本。segment_size<=0表示不分段。 | -| SDP/DP混合比 | sdp_ratio | false | 从`config.yml`中获取 | int | SDP在合成时的占比,理论上此比率越高,合成的语音语调方差越大。 | -| 情感控制 | emotion | false | None | | Bert-VITS2 v2.1可用,范围为0-9 | -| 情感参考音频 | reference_audio | false | None | | Bert-VITS2 v2.1可用 | +| Name | Parameter | Is must | Default | Type | Instruction | +| -------------- | --------------- | ------- | -------------------- | ----- | ------------------------------------------------------------ | +| 合成文本 | text | true | | str | 需要合成语音的文本。 | +| 角色id | id | false | 从`config.yml`中获取 | int | 即说话人id。 | +| 音频格式 | format | false | 从`config.yml`中获取 | str | 支持wav,ogg,silk,mp3,flac | +| 文本语言 | lang | false | 从`config.yml`中获取 | str | auto为自动识别语言模式,也是默认模式,但目前只支持识别整段文本的语言,无法细分到每个句子。其余可选语言zh和ja。 | +| 语音长度/语速 | length | false | 从`config.yml`中获取 | float | 调节语音长度,相当于调节语速,该数值越大语速越慢。 | +| 噪声 | noise | false | 从`config.yml`中获取 | float | 样本噪声,控制合成的随机性。 | +| sdp噪声 | noisew | false | 从`config.yml`中获取 | float | 随机时长预测器噪声,控制音素发音长度。 | +| 分段阈值 | segment_size | false | 从`config.yml`中获取 | int | 按标点符号分段,加起来大于segment_size时为一段文本。segment_size<=0表示不分段。 | +| SDP/DP混合比 | sdp_ratio | false | 从`config.yml`中获取 | int | SDP在合成时的占比,理论上此比率越高,合成的语音语调方差越大。 | +| 情感控制 | emotion | false | None | int | Bert-VITS2 
v2.1可用,范围为0-9 | +| 情感参考音频 | reference_audio | false | None | | Bert-VITS2 v2.1 使用参考音频来控制合成音频的情感 | +| 文本提示词 | text_prompt | false | None | str | Bert-VITS2 v2.2 文本提示词,用于控制情感 | +| 文本提示词 | style_text | false | None | str | Bert-VITS2 v2.3 文本提示词,用于控制情感 | +| 文本提示词权重 | style_weight | false | 从`config.yml`中获取 | float | Bert-VITS2 v2.3 文本提示词,用于提示词权重 | ## SSML语音合成标记语言 目前支持的元素与属性 diff --git a/TTSManager.py b/TTSManager.py index 486d2ff..a101288 100644 --- a/TTSManager.py +++ b/TTSManager.py @@ -374,7 +374,8 @@ def bert_vits2_infer(self, state, encode=True): for sentence in sentences: audio = model.infer(sentence, state["id"], lang, state["sdp_ratio"], state["noise"], state["noise"], length, emotion=state["emotion"], - reference_audio=state["reference_audio"], text_prompt=state["text_prompt"]) + reference_audio=state["reference_audio"], text_prompt=state["text_prompt"], + style_text=state["style_text"], style_weight=state["style_weight"]) audios.append(audio) audio = np.concatenate(audios) diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index 23a8e2b..e8bc2ae 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -1,3 +1,5 @@ +import logging + import torch from bert_vits2 import commons @@ -5,6 +7,7 @@ from bert_vits2.clap_wrapper import get_clap_audio_feature, get_clap_text_feature from bert_vits2.get_emo import get_emo from bert_vits2.models import SynthesizerTrn +from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230 from bert_vits2.text import * from bert_vits2.text.cleaner import clean_text from bert_vits2.utils import process_legacy_versions @@ -80,14 +83,19 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs): self.num_tones = num_tones if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"}) if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"}) - - # else: - # self.hps_ms.model.n_layers_trans_flow = 4 - # self.hps_ms.model.emotion_embedding = 1 - # self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"]) - # self.num_tones = num_tones - # if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"}) - # if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"}) + elif self.version in ["2.3", "2.3.0"]: + self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"]) + self.num_tones = num_tones + self.text_extra_str_map.update({"en": "_v230"}) + if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"}) + if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"}) + else: + logging.debug("Version information not found. 
Loaded as the newest version: v2.3.") + self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"]) + self.num_tones = num_tones + self.text_extra_str_map.update({"en": "_v230"}) + if "ja" in self.lang: self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"}) + if "en" in self.lang: self.bert_model_names.update({"en": "DEBERTA_V3_LARGE"}) if "zh" in self.lang: self.bert_model_names.update({"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"}) @@ -99,22 +107,31 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs): def load_model(self, model_handler): self.model_handler = model_handler - self.net_g = SynthesizerTrn( - len(self.symbols), - self.hps_ms.data.filter_length // 2 + 1, - self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, - n_speakers=self.hps_ms.data.n_speakers, - symbols=self.symbols, - ja_bert_dim=self.ja_bert_dim, - num_tones=self.num_tones, - **self.hps_ms.model).to(self.device) + if self.version in ["2.3", "2.3.0"]: + self.net_g = SynthesizerTrn_v230( + len(symbols), + self.hps_ms.data.filter_length // 2 + 1, + self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, + n_speakers=self.hps_ms.data.n_speakers, + **self.hps_ms.model, + ).to(self.device) + else: + self.net_g = SynthesizerTrn( + len(self.symbols), + self.hps_ms.data.filter_length // 2 + 1, + self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, + n_speakers=self.hps_ms.data.n_speakers, + symbols=self.symbols, + ja_bert_dim=self.ja_bert_dim, + num_tones=self.num_tones, + **self.hps_ms.model).to(self.device) _ = self.net_g.eval() bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True, version=self.version) def get_speakers(self): return self.speakers - def get_text(self, text, language_str, hps): + def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7): clean_text_lang_str = language_str + self.text_extra_str_map.get(language_str, "") bert_feature_lang_str = language_str + self.bert_extra_str_map.get(language_str, "") @@ -132,8 +149,9 @@ def get_text(self, text, language_str, hps): word2ph[i] = word2ph[i] * 2 word2ph[0] += 1 + style_text = None if style_text == "" else style_text bert = self.model_handler.get_bert_feature(norm_text, word2ph, bert_feature_lang_str, - self.bert_model_names[language_str]) + self.bert_model_names[language_str], style_text, style_weight) del word2ph assert bert.shape[-1] == len(phone), phone @@ -173,8 +191,9 @@ def get_emo_(self, reference_audio, emotion): return emo def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None, - skip_start=False, skip_end=False, text_prompt=None, **kwargs): - zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms) + skip_start=False, skip_end=False, text_prompt=None, style_text=None, style_weigth=0.7, **kwargs): + zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms, style_text, + style_weigth) if self.hps_ms.model.emotion_embedding == 1: emo = self.get_emo_(reference_audio, emotion).to(self.device).unsqueeze(0) diff --git a/bert_vits2/model_handler.py b/bert_vits2/model_handler.py index 08ba350..1a08290 100644 --- a/bert_vits2/model_handler.py +++ b/bert_vits2/model_handler.py @@ -14,7 +14,6 @@ from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200 from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200 - class ModelHandler: def __init__(self, device): 
self.DOWNLOAD_PATHS = { @@ -47,10 +46,12 @@ def __init__(self, device): "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin", ], "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [ - + "https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin", + "https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin", ], "CLAP_HTSAT_FUSED": [ - + "https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true", + "https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true", ] } @@ -62,7 +63,8 @@ def __init__(self, device): "DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5", "SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd", "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201", - "CLAP_HTSAT_FUSED": "" + "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816", + "CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253", } self.model_path = { "CHINESE_ROBERTA_WWM_EXT_LARGE": os.path.join(config.ABS_PATH, @@ -141,28 +143,45 @@ def load_bert(self, bert_model_name, max_retries=3): tokenizer, model, count = self.bert_models[bert_model_name] self.bert_models[bert_model_name] = (tokenizer, model, count + 1) - def load_emotion(self): + def load_emotion(self, max_retries=3): """Bert-VITS2 v2.1 EmotionModel""" if self.emotion is None: from transformers import Wav2Vec2Processor from bert_vits2.get_emo import EmotionModel - self.emotion = {} - self.emotion["model"] = EmotionModel.from_pretrained( - self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"]).to(self.device) - self.emotion["processor"] = Wav2Vec2Processor.from_pretrained( - self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"]) - self.emotion["reference_count"] = 1 + retries = 0 + model_path = self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"] + while retries < max_retries: + try: + self.emotion = {} + self.emotion["model"] = EmotionModel.from_pretrained(model_path).to(self.device) + self.emotion["processor"] = Wav2Vec2Processor.from_pretrained(model_path) + self.emotion["reference_count"] = 1 + break + except Exception as e: + logging.error(f"Failed loading {model_path}. {e}") + self._download_model("WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM") + retries += 1 else: self.emotion["reference_count"] += 1 - def load_clap(self): + def load_clap(self, max_retries=3): """Bert-VITS2 v2.2 ClapModel""" if self.clap is None: from transformers import ClapModel, ClapProcessor - self.clap = {} - self.clap["model"] = ClapModel.from_pretrained(self.model_path["CLAP_HTSAT_FUSED"]).to(self.device) - self.clap["processor"] = ClapProcessor.from_pretrained(self.model_path["CLAP_HTSAT_FUSED"]) - self.clap["reference_count"] = 1 + retries = 0 + model_path = self.model_path["CLAP_HTSAT_FUSED"] + while retries < max_retries: + try: + self.clap = {} + self.clap["model"] = ClapModel.from_pretrained(model_path).to(self.device) + self.clap["processor"] = ClapProcessor.from_pretrained(model_path) + self.clap["reference_count"] = 1 + break + except Exception as e: + logging.error(f"Failed loading {model_path}. 
{e}") + self._download_model("CLAP_HTSAT_FUSED") + retries += 1 + else: self.clap["reference_count"] += 1 @@ -173,9 +192,10 @@ def get_bert_model(self, bert_model_name): tokenizer, model, _ = self.bert_models[bert_model_name] return tokenizer, model - def get_bert_feature(self, norm_text, word2ph, language, bert_model_name): + def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7): tokenizer, model = self.get_bert_model(bert_model_name) - bert_feature = self.lang_bert_func_map[language](norm_text, word2ph, tokenizer, model, self.device) + bert_feature = self.lang_bert_func_map[language](norm_text, word2ph, tokenizer, model, self.device, style_text, + style_weight) return bert_feature def release_bert(self, bert_model_name): diff --git a/bert_vits2/models_v230.py b/bert_vits2/models_v230.py new file mode 100644 index 0000000..56211be --- /dev/null +++ b/bert_vits2/models_v230.py @@ -0,0 +1,985 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from bert_vits2 import commons +from bert_vits2 import modules +from bert_vits2 import attentions + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from bert_vits2.commons import init_weights, get_padding +from bert_vits2.text import symbols, num_tones, num_languages + + +class DurationDiscriminator(nn.Module): # vits2 + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.dur_proj = nn.Conv1d(1, filter_channels, 1) + + self.LSTM = nn.LSTM( + 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True + ) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + self.output_layer = nn.Sequential( + nn.Linear(2 * filter_channels, 1), nn.Sigmoid() + ) + + def forward_probability(self, x, dur): + dur = self.dur_proj(dur) + x = torch.cat([x, dur], dim=1) + x = x.transpose(1, 2) + x, _ = self.LSTM(x) + output_prob = self.output_layer(x) + return output_prob + + def forward(self, x, x_mask, dur_r, dur_hat, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + + output_probs = [] + for dur in [dur_r, dur_hat]: + output_prob = self.forward_probability(x, dur) + output_probs.append(output_prob) + + return output_probs + + +class TransformerCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = 
gin_channels + + self.flows = nn.ModuleList() + + self.wn = ( + attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=self.gin_channels, + ) + if share_parameter + else None + ) + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer( + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout, + filter_channels, + mean_only=True, + wn_sharing_parameter=self.wn, + gin_channels=self.gin_channels, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * 
x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class Bottleneck(nn.Sequential): + def __init__(self, in_dim, hidden_dim): + c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) + c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) + super().__init__(*[c_fc1, c_fc2]) + + +class Block(nn.Module): + def __init__(self, in_dim, hidden_dim) -> None: + super().__init__() + self.norm = nn.LayerNorm(in_dim) + self.mlp = MLP(in_dim, hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.mlp(self.norm(x)) + return x + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_dim): + super().__init__() + self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) + self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False) + + def forward(self, x: torch.Tensor): + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=0, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + self.emb = nn.Embedding(len(symbols), hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + self.tone_emb = nn.Embedding(num_tones, hidden_channels) + nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels ** -0.5) + self.language_emb = nn.Embedding(num_languages, hidden_channels) + nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5) + self.bert_proj = nn.Conv1d(1024, hidden_channels, 1) + self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1) + self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1) + + self.encoder = attentions.Encoder( + 
hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=self.gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, g=None): + zh_bert_emb = self.bert_proj(zh_bert).transpose(1, 2) + ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2) + en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2) + x = ( + self.emb(x) + + self.tone_emb(tone) + + self.language_emb(language) + + zh_bert_emb + + ja_bert_emb + + en_bert_emb + ) * math.sqrt( + self.hidden_channels + ) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask, g=g) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, 
upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + 
fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class WavLMDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__( + self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False + ): + super(WavLMDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.pre = norm_f( + Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0) + ) + + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv1d( + initial_channel, initial_channel * 2, kernel_size=5, padding=2 + ) + ), + norm_f( + nn.Conv1d( + initial_channel * 2, + initial_channel * 4, + kernel_size=5, + padding=2, + ) + ), + norm_f( + nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2) + ), + ] + ) + + self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) + + def forward(self, x): + x = self.pre(x) + + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501 + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + + def forward(self, inputs, mask=None): + N = inputs.size(0) + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + 
self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=256, + gin_channels=256, + use_sdp=True, + n_flow_layer=4, + n_layers_trans_flow=4, + flow_share_parameter=False, + use_transformer_flow=True, + **kwargs + ): + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.n_layers_trans_flow = n_layers_trans_flow + self.use_spk_conditioned_encoder = kwargs.get( + "use_spk_conditioned_encoder", True + ) + self.use_sdp = use_sdp + self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False) + self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01) + self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6) + self.current_mas_noise_scale = self.mas_noise_scale_initial + if self.use_spk_conditioned_encoder and gin_channels > 0: + self.enc_gin_channels = gin_channels + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=self.enc_gin_channels, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + if use_transformer_flow: + self.flow = TransformerCouplingBlock( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers_trans_flow, + 5, + p_dropout, + n_flow_layer, + gin_channels=gin_channels, + share_parameter=flow_share_parameter, + ) + else: + self.flow = ResidualCouplingBlock( + inter_channels, + hidden_channels, + 5, + 1, + n_flow_layer, + gin_channels=gin_channels, + ) + self.sdp = StochasticDurationPredictor( + hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels + ) + self.dp = DurationPredictor( + hidden_channels, 256, 3, 0.5, gin_channels=gin_channels + ) + + if n_speakers >= 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + else: + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) + + def infer( + self, + x, + x_lengths, + sid, + tone, + language, + bert, + ja_bert, + en_bert, + noise_scale=0.667, + length_scale=1, + noise_scale_w=0.8, + max_len=None, + sdp_ratio=0, + y=None, + **kwargs, + ): + # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert) + # g = self.gst(y) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1) + x, m_p, logs_p, x_mask = self.enc_p( + x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g + ) + logw 
= self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * ( + sdp_ratio + ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to( + x_mask.dtype + ) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) diff --git a/bert_vits2/text/chinese_bert.py b/bert_vits2/text/chinese_bert.py index c9fa707..84eeed7 100644 --- a/bert_vits2/text/chinese_bert.py +++ b/bert_vits2/text/chinese_bert.py @@ -3,19 +3,32 @@ from utils.config_manager import global_config -def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE): +def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7, ): with torch.no_grad(): inputs = tokenizer(text, return_tensors='pt') for i in inputs: inputs[i] = inputs[i].to(device) res = model(**inputs, output_hidden_states=True) res = torch.cat(res['hidden_states'][-3:-2], -1)[0].cpu() + if style_text: + style_inputs = tokenizer(style_text, return_tensors="pt") + for i in style_inputs: + style_inputs[i] = style_inputs[i].to(device) + style_res = model(**style_inputs, output_hidden_states=True) + style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() + style_res_mean = style_res.mean(0) assert len(word2ph) == len(text) + 2 word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): - repeat_feature = res[i].repeat(word2phone[i], 1) + if style_text: + repeat_feature = ( + res[i].repeat(word2phone[i], 1) * (1 - style_weight) + + style_res_mean.repeat(word2phone[i], 1) * style_weight + ) + else: + repeat_feature = res[i].repeat(word2phone[i], 1) phone_level_feature.append(repeat_feature) phone_level_feature = torch.cat(phone_level_feature, dim=0) diff --git a/bert_vits2/text/cleaner.py b/bert_vits2/text/cleaner.py index 0d52a1f..346407a 100644 --- a/bert_vits2/text/cleaner.py +++ b/bert_vits2/text/cleaner.py @@ -1,5 +1,5 @@ from bert_vits2.text import chinese, japanese, english, cleaned_text_to_sequence, japanese_v111, chinese_v100, \ - japanese_v200, english_v200 + japanese_v200, english_v200, english_v230 language_module_map = { 'zh': chinese, @@ -8,7 +8,8 @@ 'ja_v111': japanese_v111, 'zh_v100': chinese_v100, 'ja_v200': japanese_v200, - 'en_v200': english_v200 + 'en_v200': english_v200, + 'en_v230': english_v230, } diff --git a/bert_vits2/text/english_bert_mock_v200.py b/bert_vits2/text/english_bert_mock_v200.py index 6ef94b8..3d66d12 100644 --- a/bert_vits2/text/english_bert_mock_v200.py +++ b/bert_vits2/text/english_bert_mock_v200.py @@ -4,18 +4,32 @@ -def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE): +def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: 
inputs[i] = inputs[i].to(device) res = model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() - # assert len(word2ph) == len(text)+2 + if style_text: + style_inputs = tokenizer(style_text, return_tensors="pt") + for i in style_inputs: + style_inputs[i] = style_inputs[i].to(device) + style_res = model(**style_inputs, output_hidden_states=True) + style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() + style_res_mean = style_res.mean(0) + assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph)) + assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph)) word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): - repeat_feature = res[i].repeat(word2phone[i], 1) + if style_text: + repeat_feature = ( + res[i].repeat(word2phone[i], 1) * (1 - style_weight) + + style_res_mean.repeat(word2phone[i], 1) * style_weight + ) + else: + repeat_feature = res[i].repeat(word2phone[i], 1) phone_level_feature.append(repeat_feature) phone_level_feature = torch.cat(phone_level_feature, dim=0) diff --git a/bert_vits2/text/english_v230.py b/bert_vits2/text/english_v230.py new file mode 100644 index 0000000..4d670cc --- /dev/null +++ b/bert_vits2/text/english_v230.py @@ -0,0 +1,493 @@ +import pickle +import os +from g2p_en import G2p +from transformers import DebertaV2Tokenizer + +from bert_vits2.text import symbols +from bert_vits2.text.symbols import punctuation + +current_file_path = os.path.dirname(__file__) +CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") +CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle") +_g2p = G2p() +LOCAL_PATH = "./bert/deberta-v3-large" +# tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH) + +arpa = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} + + +def post_replace_ph(ph): + rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "…": "...", + "···": "...", + "・・・": "...", + "v": "V", + } + if ph in rep_map.keys(): + ph = rep_map[ph] + if ph in symbols: + return ph + if ph not in symbols: + ph = "UNK" + return ph + + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + ".": ".", + "…": "...", + "···": "...", + "・・・": "...", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "−": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", +} + + +def replace_punctuation(text): + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + # replaced_text = re.sub( + # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" + # + "".join(punctuation) + # + r"]+", + # "", + # replaced_text, + # ) + + return 
replaced_text + + +def read_dict(): + g2p_dict = {} + start_line = 49 + with open(CMU_DICT_PATH) as f: + line = f.readline() + line_index = 1 + while line: + if line_index >= start_line: + line = line.strip() + word_split = line.split(" ") + word = word_split[0] + + syllable_split = word_split[1].split(" - ") + g2p_dict[word] = [] + for syllable in syllable_split: + phone_split = syllable.split(" ") + g2p_dict[word].append(phone_split) + + line_index = line_index + 1 + line = f.readline() + + return g2p_dict + + +def cache_dict(g2p_dict, file_path): + with open(file_path, "wb") as pickle_file: + pickle.dump(g2p_dict, pickle_file) + + +def get_dict(): + if os.path.exists(CACHE_PATH): + with open(CACHE_PATH, "rb") as pickle_file: + g2p_dict = pickle.load(pickle_file) + else: + g2p_dict = read_dict() + cache_dict(g2p_dict, CACHE_PATH) + + return g2p_dict + + +eng_dict = get_dict() + + +def refine_ph(phn): + tone = 0 + if re.search(r"\d$", phn): + tone = int(phn[-1]) + 1 + phn = phn[:-1] + else: + tone = 3 + return phn.lower(), tone + + +def refine_syllables(syllables): + tones = [] + phonemes = [] + for phn_list in syllables: + for i in range(len(phn_list)): + phn = phn_list[i] + phn, tone = refine_ph(phn) + phonemes.append(phn) + tones.append(tone) + return phonemes, tones + + +import re +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + +# List of (ipa, lazy ipa) pairs: +_lazy_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("r", "ɹ"), + ("æ", "e"), + ("ɑ", "a"), + ("ɔ", "o"), + ("ð", "z"), + ("θ", "s"), + ("ɛ", "e"), + ("ɪ", "i"), + ("ʊ", "u"), + ("ʒ", "ʥ"), + ("ʤ", "ʥ"), + ("ˈ", "↓"), + ] +] + +# List of (ipa, lazy ipa2) pairs: +_lazy_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("r", "ɹ"), + ("ð", "z"), + ("θ", "s"), + ("ʒ", "ʑ"), + ("ʤ", "dʑ"), + ("ˈ", "↓"), + ] +] + +# List of (ipa, ipa2) pairs +_ipa_to_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")] +] + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _remove_commas(m): + return m.group(1).replace(",", 
"") + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +def text_normalize(text): + text = normalize_numbers(text) + text = replace_punctuation(text) + text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) + return text + + +def distribute_phone(n_phone, n_word): + phones_per_word = [0] * n_word + for task in range(n_phone): + min_tasks = min(phones_per_word) + min_index = phones_per_word.index(min_tasks) + phones_per_word[min_index] += 1 + return phones_per_word + + +def sep_text(text): + words = re.split(r"([,;.\?\!\s+])", text) + words = [word for word in words if word.strip() != ""] + return words + + +def text_to_words(text, tokenizer): + tokens = tokenizer.tokenize(text) + words = [] + for idx, t in enumerate(tokens): + if t.startswith("▁"): + words.append([t[1:]]) + else: + if t in punctuation: + if idx == len(tokens) - 1: + words.append([f"{t}"]) + else: + if ( + not tokens[idx + 1].startswith("▁") + and tokens[idx + 1] not in punctuation + ): + if idx == 0: + words.append([]) + words[-1].append(f"{t}") + else: + words.append([f"{t}"]) + else: + if idx == 0: + words.append([]) + words[-1].append(f"{t}") + return words + + +def g2p(text, tokenizer): + phones = [] + tones = [] + phone_len = [] + # words = sep_text(text) + # tokens = [tokenizer.tokenize(i) for i in words] + words = text_to_words(text, tokenizer) + + for word in words: + temp_phones, temp_tones = [], [] + if len(word) > 1: + if "'" in word: + word = ["".join(word)] + for w in word: + if w in punctuation: + temp_phones.append(w) + temp_tones.append(0) + continue + if w.upper() in eng_dict: + phns, tns = refine_syllables(eng_dict[w.upper()]) + temp_phones += [post_replace_ph(i) for i in phns] + temp_tones += tns + # w2ph.append(len(phns)) + else: + phone_list = list(filter(lambda p: p != " ", _g2p(w))) + phns = [] + tns = [] + for ph in phone_list: + if ph in arpa: + ph, tn = refine_ph(ph) + phns.append(ph) + tns.append(tn) + else: + phns.append(ph) + tns.append(0) + temp_phones += [post_replace_ph(i) for i in phns] + temp_tones += tns + phones += temp_phones + tones += temp_tones + phone_len.append(len(temp_phones)) + # phones = [post_replace_ph(i) for i in phones] + + word2ph = [] + for token, pl in zip(words, phone_len): + word_len = len(token) + + aaa = distribute_phone(pl, word_len) + word2ph += aaa + + phones = ["_"] + phones + ["_"] + tones = [0] + tones + [0] + word2ph = [1] + word2ph + [1] + assert len(phones) == len(tones), text + assert len(phones) == sum(word2ph), text + + return phones, tones, word2ph + + +def get_bert_feature(text, word2ph): + from bert_vits2.text import 
english_bert_mock + + return english_bert_mock.get_bert_feature(text, word2ph) + + +if __name__ == "__main__": + # print(get_dict()) + # print(eng_word_to_phoneme("hello")) + print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) + # all_phones = set() + # for k, syllables in eng_dict.items(): + # for group in syllables: + # for ph in group: + # all_phones.add(ph) + # print(all_phones) diff --git a/bert_vits2/text/japanese_bert.py b/bert_vits2/text/japanese_bert.py index d7ecdbe..3358d22 100644 --- a/bert_vits2/text/japanese_bert.py +++ b/bert_vits2/text/japanese_bert.py @@ -6,20 +6,35 @@ LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm" -def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE): +def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVICE, style_text=None, style_weight=0.7): text = "".join(text2sep_kata(text)[0]) + if style_text: + style_text = "".join(text2sep_kata(style_text)[0]) with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) res = model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() + if style_text: + style_inputs = tokenizer(style_text, return_tensors="pt") + for i in style_inputs: + style_inputs[i] = style_inputs[i].to(device) + style_res = model(**style_inputs, output_hidden_states=True) + style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() + style_res_mean = style_res.mean(0) assert len(word2ph) == len(text) + 2 word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): - repeat_feature = res[i].repeat(word2phone[i], 1) + if style_text: + repeat_feature = ( + res[i].repeat(word2phone[i], 1) * (1 - style_weight) + + style_res_mean.repeat(word2phone[i], 1) * style_weight + ) + else: + repeat_feature = res[i].repeat(word2phone[i], 1) phone_level_feature.append(repeat_feature) phone_level_feature = torch.cat(phone_level_feature, dim=0) diff --git a/config.py b/config.py index 9b17199..4f65cf5 100644 --- a/config.py +++ b/config.py @@ -130,3 +130,4 @@ LENGTH_ZH = 0 LENGTH_JA = 0 LENGTH_EN = 0 +STYLE_WEIGHT = 0.7 diff --git a/tts_app/static/js/index.js b/tts_app/static/js/index.js index 2f5dde4..84942cc 100644 --- a/tts_app/static/js/index.js +++ b/tts_app/static/js/index.js @@ -68,6 +68,8 @@ function getLink() { let sdp_ratio = null; let emotion = null; let text_prompt = ""; + let style_text = ""; + let style_weight = ""; if (currentModelPage == 1) { streaming = document.getElementById('streaming1'); url += "/voice/vits?text=" + text + "&id=" + id; @@ -79,6 +81,8 @@ function getLink() { streaming = document.getElementById('streaming3'); emotion = document.getElementById('input_emotion3').value; text_prompt = document.getElementById('input_text_prompt3').value; + style_text = document.getElementById('input_style_text3').value; + style_weight = document.getElementById('input_style_weight3').value; url += "/voice/bert-vits2?text=" + text + "&id=" + id; } else { @@ -125,6 +129,10 @@ function getLink() { url += "&emotion=" + emotion; if (text_prompt !== null && text_prompt !== "") url += "&text_prompt=" + text_prompt; + if (style_text !== null && style_text !== "") + url += "&style_text=" + style_text; + if (style_weight !== null && style_weight !== "") + url += "&style_weight=" + style_weight; } return url; @@ -202,6 +210,8 @@ function setAudioSourceByPost() { let length_en = 0; let emotion = null; let text_prompt = ""; + let style_text = 
""; + let style_weight = ""; if (currentModelPage == 1) { url = baseUrl + "/voice/vits"; @@ -218,6 +228,8 @@ function setAudioSourceByPost() { length_en = $("#input_length_en3").val(); emotion = $("#input_emotion3").val(); text_prompt = $("#input_text_prompt3").val(); + style_text = $("#input_style_text3").val(); + style_weight = $("#input_style_weight3").val(); } // 添加其他配置参数到 FormData @@ -245,6 +257,12 @@ function setAudioSourceByPost() { if (currentModelPage == 3 && text_prompt) { formData.append('text_prompt', text_prompt); } + if (currentModelPage == 3 && style_text) { + formData.append('style_text', style_text); + } + if (currentModelPage == 3 && style_weight) { + formData.append('style_weight', style_weight); + } let downloadButton = document.getElementById("downloadButton" + currentModelPage); diff --git a/tts_app/templates/pages/index.html b/tts_app/templates/pages/index.html index 4db19d9..16899a2 100644 --- a/tts_app/templates/pages/index.html +++ b/tts_app/templates/pages/index.html @@ -351,11 +351,26 @@