Update Bert-VITS2 extra.
Add phrases_dict.
Artrajz committed Jan 8, 2024
1 parent 72743a3 commit 1b89d8c
Showing 6 changed files with 91 additions and 41 deletions.
6 changes: 5 additions & 1 deletion TTSManager.py
@@ -2,6 +2,7 @@
import re
import numpy as np
import xml.etree.ElementTree as ET

from utils.config_manager import global_config as config
import soundfile as sf
from io import BytesIO
@@ -442,7 +443,10 @@ def bert_vits2_infer_v2(self, state, encode=True):
sentences_list = sentence_split(state["text"], state["segment_size"])
audios = []
for sentences in sentences_list:
if state["lang"].lower() == "auto":
if model.zh_bert_extra:
infer_func = model.infer
state["lang"] = "zh"
elif state["lang"].lower() == "auto":
infer_func = model.infer_multilang
else:
infer_func = model.infer
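The new branch short-circuits language detection for "extra" checkpoints, which carry only a Chinese BERT. Condensed for readability, the selection logic reads roughly as follows (a sketch; `pick_infer_func` is not a function in this repo):

def pick_infer_func(model, state):
    # "extra" checkpoints ship only a Chinese BERT, so force single-language
    # Chinese inference regardless of the requested language.
    if model.zh_bert_extra:
        state["lang"] = "zh"
        return model.infer
    if state["lang"].lower() == "auto":
        return model.infer_multilang  # per-segment language detection
    return model.infer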
47 changes: 23 additions & 24 deletions bert_vits2/bert_vits2.py
@@ -127,25 +127,20 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
def load_model(self, model_handler):
self.model_handler = model_handler

if self.version == "2.3":
self.net_g = SynthesizerTrn_v230(
len(symbols),
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
**self.hps_ms.model,
).to(self.device)
if self.version in ["2.3", "extra"]:
Synthesizer = SynthesizerTrn_v230
else:
self.net_g = SynthesizerTrn(
len(self.symbols),
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
symbols=self.symbols,
ja_bert_dim=self.ja_bert_dim,
num_tones=self.num_tones,
zh_bert_extra=self.zh_bert_extra,
**self.hps_ms.model).to(self.device)
Synthesizer = SynthesizerTrn
self.net_g = Synthesizer(
len(self.symbols),
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
symbols=self.symbols,
ja_bert_dim=self.ja_bert_dim,
num_tones=self.num_tones,
zh_bert_extra=self.zh_bert_extra,
**self.hps_ms.model).to(self.device)
_ = self.net_g.eval()
bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True, version=self.version)

@@ -176,7 +171,10 @@ def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
del word2ph
assert bert.shape[-1] == len(phone), phone

if language_str == "zh" or self.zh_bert_extra:
if self.zh_bert_extra:
zh_bert = bert
ja_bert, en_bert = None, None
elif language_str == "zh":
zh_bert = bert
ja_bert = torch.zeros(self.ja_bert_dim, len(phone))
en_bert = torch.zeros(1024, len(phone))
@@ -229,18 +227,19 @@ def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_rat
tones = tones.to(self.device).unsqueeze(0)
lang_ids = lang_ids.to(self.device).unsqueeze(0)
zh_bert = zh_bert.to(self.device).unsqueeze(0)
ja_bert = ja_bert.to(self.device).unsqueeze(0)
en_bert = en_bert.to(self.device).unsqueeze(0)
if not self.zh_bert_extra:
ja_bert = ja_bert.to(self.device).unsqueeze(0)
en_bert = en_bert.to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
speakers = torch.LongTensor([int(id)]).to(self.device)
audio = self.net_g.infer(x_tst,
x_tst_lengths,
speakers,
tones,
lang_ids,
zh_bert,
ja_bert,
en_bert,
zh_bert=zh_bert,
ja_bert=ja_bert,
en_bert=en_bert,
sdp_ratio=sdp_ratio,
noise_scale=noise,
noise_scale_w=noisew,
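Two things change in `_infer`: the device moves for `ja_bert`/`en_bert` are now guarded, because the extra model passes `None` for both, and the BERT features are passed as keyword arguments so the call stays correct across the differing `infer` signatures. The guard amounts to this pattern (a sketch; `maybe_to_device` is a hypothetical helper, not part of this commit):

import torch

def maybe_to_device(t, device):
    # Optional BERT features stay None instead of being batched and moved.
    return t.to(device).unsqueeze(0) if t is not None else None

ja_bert = maybe_to_device(None, torch.device("cpu"))  # extra model: remains None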
9 changes: 8 additions & 1 deletion bert_vits2/model_handler.py
@@ -164,16 +164,20 @@ def load_emotion(self, max_retries=3):
retries = 0
model_path = self.model_path["WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM"]
while retries < max_retries:
logging.info(f"Loading WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM: {model_path}")
try:
self.emotion = {}
self.emotion["model"] = EmotionModel.from_pretrained(model_path).to(self.device)
self.emotion["processor"] = Wav2Vec2Processor.from_pretrained(model_path)
self.emotion["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
self._download_model("WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM")
retries += 1
if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
self.emotion["reference_count"] += 1

@@ -184,17 +188,20 @@ def load_clap(self, max_retries=3):
retries = 0
model_path = self.model_path["CLAP_HTSAT_FUSED"]
while retries < max_retries:
logging.info(f"Loading CLAP_HTSAT_FUSED: {model_path}")
try:
self.clap = {}
self.clap["model"] = ClapModel.from_pretrained(model_path).to(self.device)
self.clap["processor"] = ClapProcessor.from_pretrained(model_path)
self.clap["reference_count"] = 1
logging.info(f"Success loading: {model_path}")
break
except Exception as e:
logging.error(f"Failed loading {model_path}. {e}")
self._download_model("CLAP_HTSAT_FUSED")
retries += 1

if retries == max_retries:
logging.error(f"Failed to load {model_path} after {max_retries} retries.")
else:
self.clap["reference_count"] += 1

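Both loaders now follow the same load-or-download retry shape, with a log line per attempt and a final error once retries are exhausted. Factored out, the pattern would look roughly like this (illustrative only; the repo keeps the loops inline):

import logging

def load_with_retry(load_fn, download_fn, name, max_retries=3):
    # Try to load; on failure, (re)download the weights and try again.
    for attempt in range(max_retries):
        logging.info(f"Loading {name} (attempt {attempt + 1}/{max_retries})")
        try:
            return load_fn()
        except Exception as e:
            logging.error(f"Failed loading {name}. {e}")
            download_fn()
    logging.error(f"Failed to load {name} after {max_retries} retries.")
    return None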
56 changes: 48 additions & 8 deletions bert_vits2/models_v230.py
@@ -2,6 +2,7 @@
import torch
from torch import nn
from torch.nn import functional as F
from vector_quantize_pytorch import VectorQuantize

from bert_vits2 import commons
from bert_vits2 import modules
@@ -341,6 +342,7 @@ def __init__(
kernel_size,
p_dropout,
gin_channels=0,
zh_bert_extra=False,
):
super().__init__()
self.n_vocab = n_vocab
@@ -359,8 +361,36 @@ def __init__(
self.language_emb = nn.Embedding(num_languages, hidden_channels)
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5)
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.zh_bert_extra = zh_bert_extra
if self.zh_bert_extra:
self.bert_pre_proj = nn.Conv1d(2048, 1024, 1)
self.in_feature_net = nn.Sequential(
# input is assumed to be an already-normalized embedding
nn.Linear(512, 1028, bias=False),
nn.GELU(),
nn.LayerNorm(1028),
*[Block(1028, 512) for _ in range(1)],
nn.Linear(1028, 512, bias=False),
# normalize before passing to VQ?
# nn.GELU(),
# nn.LayerNorm(512),
)
self.emo_vq = VectorQuantize(
dim=512,
codebook_size=64,
codebook_dim=32,
commitment_weight=0.1,
decay=0.85,
heads=32,
kmeans_iters=20,
separate_codebook_per_head=True,
stochastic_sample_codes=True,
threshold_ema_dead_code=2,
)
self.out_feature_net = nn.Linear(512, hidden_channels)
else:
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

self.encoder = attentions.Encoder(
hidden_channels,
@@ -373,12 +403,19 @@ def __init__(
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, g=None):
def forward(self, x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, emo=None, g=None):
x = self.emb(x) + self.tone_emb(tone) + self.language_emb(language)

x +=self.bert_proj(zh_bert).transpose(1, 2)
x += self.ja_bert_proj(ja_bert).transpose(1, 2)
x += self.en_bert_proj(en_bert).transpose(1, 2)

if self.zh_bert_extra:
zh_bert = self.bert_pre_proj(zh_bert)
emo_emb = self.in_feature_net(emo)
emo_emb, _, _ = self.emo_vq(emo_emb.unsqueeze(1))
emo_emb = self.out_feature_net(emo_emb)
x += emo_emb
x += self.bert_proj(zh_bert).transpose(1, 2)
if not self.zh_bert_extra:
x += self.ja_bert_proj(ja_bert).transpose(1, 2)
x += self.en_bert_proj(en_bert).transpose(1, 2)

x *= math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
@@ -831,6 +868,7 @@ def __init__(
n_layers_trans_flow=4,
flow_share_parameter=False,
use_transformer_flow=True,
zh_bert_extra=False,
**kwargs
):
super().__init__()
@@ -873,6 +911,7 @@ def __init__(
kernel_size,
p_dropout,
gin_channels=self.enc_gin_channels,
zh_bert_extra=zh_bert_extra,
)
self.dec = Generator(
inter_channels,
@@ -937,6 +976,7 @@ def infer(
zh_bert,
ja_bert,
en_bert,
emo=None,
noise_scale=0.667,
length_scale=1,
noise_scale_w=0.8,
@@ -952,7 +992,7 @@
else:
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.enc_p(
x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, g=g
x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, emo=emo, g=g
)
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
sdp_ratio
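The heart of the extra model's encoder change is the quantized emotion embedding: a 512-dim utterance-level vector is projected through `in_feature_net`, vector-quantized against a small codebook, mapped to `hidden_channels`, and added to the text encoding. The branch can be exercised standalone (shapes taken from the diff; the VQ config below is trimmed to a few of the arguments used above):

import torch
from vector_quantize_pytorch import VectorQuantize

emo_vq = VectorQuantize(dim=512, codebook_size=64, codebook_dim=32, heads=32,
                        separate_codebook_per_head=True)
emo = torch.randn(1, 512)  # one emotion vector per utterance
# VectorQuantize expects (batch, seq, dim); it returns the quantized codes,
# the codebook indices, and the commitment loss.
quantized, indices, commit_loss = emo_vq(emo.unsqueeze(1))
print(quantized.shape)  # torch.Size([1, 1, 512])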
9 changes: 3 additions & 6 deletions bert_vits2/text/chinese_bert_extra.py
@@ -9,9 +9,7 @@ def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVIC
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = torch.nn.functional.normalize(
torch.cat(res["hidden_states"][-3:-2], -1)[0], dim=0
).cpu()
res = torch.nn.functional.normalize(torch.cat(res["hidden_states"][-3:-2], -1)[0], dim=0).cpu()
if style_text:
style_inputs = tokenizer(style_text, return_tensors="pt")
for i in style_inputs:
@@ -21,15 +19,14 @@ def get_bert_feature(text, word2ph, tokenizer, model, device=global_config.DEVIC
torch.cat(style_res["hidden_states"][-3:-2], -1)[0], dim=0
).cpu()
style_res_mean = style_res.mean(0)

assert len(word2ph) == len(text) + 2
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
if style_text:
repeat_feature = (
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+ style_res_mean.repeat(word2phone[i], 1) * style_weight
)
else:
repeat_feature = res[i].repeat(word2phone[i], 1)
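The `word2ph` expansion that both branches feed is easy to see with toy shapes (illustrative values, not this module's code): each character-level BERT feature is repeated once per phone it maps to, yielding phone-level features.

import torch

res = torch.randn(4, 1024)  # per-character BERT features: len(text) + 2 rows
word2ph = [1, 2, 3, 1]      # phones per character (incl. [CLS]/[SEP] slots)
# Repeat each character's feature once per phone it maps to.
phone_level = torch.cat([res[i].repeat(n, 1) for i, n in enumerate(word2ph)])
print(phone_level.shape)    # torch.Size([7, 1024]) == (sum(word2ph), 1024)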
5 changes: 4 additions & 1 deletion utils/phrases_dict.py
@@ -4,13 +4,16 @@
import jieba
import pypinyin
from pypinyin_dict.phrase_pinyin_data import large_pinyin
from pypinyin_dict.pinyin_data import cc_cedict
from pypinyin_dict.pinyin_data import cc_cedict, kxhc1983

import config

phrases_dict = {
"一骑当千": [["yí"], ["jì"], ["dāng"], ["qiān"]],
"桔子": [["jú"], ["zǐ"]],
"重生": [["chóng"], ["shēng"]],
"重重地":[["zhòng"], ["zhòng"], ["dě"]],
"自少时":[["zì"], ["shào"], ["shí"]],
}


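These entries override pypinyin's defaults for heteronyms like 重. The mechanism, in isolation (a standalone demo, not this module's registration code):

from pypinyin import pinyin, load_phrases_dict

# Without the override, pypinyin's bundled data may read 重 as "zhòng" here.
load_phrases_dict({"重生": [["chóng"], ["shēng"]]})
print(pinyin("重生"))  # [['chóng'], ['shēng']]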
