Skip to content

Commit

Permalink
🔥 clear code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 9, 2024
1 parent c05035d commit e8a1d7b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 122 deletions.
63 changes: 6 additions & 57 deletions modules/ChatTTS/ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import logging
from functools import partial
from omegaconf import OmegaConf

import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import (
count_invalid_characters,
detect_language,
Expand Down Expand Up @@ -107,9 +105,7 @@ def _load(
dtype_gpt: torch.dtype = None,
dtype_decoder: torch.dtype = None,
):
if not device:
device = select_device(4096)
self.logger.log(logging.INFO, f"use {device}")
assert device is not None, "device should not be None"

dtype_vocos = dtype_vocos or dtype
dtype_dvae = dtype_dvae or dtype
Expand Down Expand Up @@ -179,23 +175,13 @@ def infer(
params_refine_text={},
params_infer_code={"prompt": "[speed_5]"},
use_decoder=True,
do_text_normalization=True,
lang=None,
):

assert self.check_model(use_decoder=use_decoder)

if not isinstance(text, list):
text = [text]

if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
self.init_normalizer(_lang)
text[i] = self.normalizer[_lang](t)
if _lang == "zh":
text[i] = apply_half2full_map(text[i])

for i, t in enumerate(text):
reserved_tokens = self.pretrain_models[
"tokenizer"
Expand Down Expand Up @@ -251,23 +237,13 @@ def refiner_prompt(
self,
text,
params_refine_text={},
do_text_normalization=True,
lang=None,
) -> str:

# assert self.check_model(use_decoder=False)

if not isinstance(text, list):
text = [text]

if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
self.init_normalizer(_lang)
text[i] = self.normalizer[_lang](t)
if _lang == "zh":
text[i] = apply_half2full_map(text[i])

for i, t in enumerate(text):
reserved_tokens = self.pretrain_models[
"tokenizer"
Expand Down Expand Up @@ -305,7 +281,10 @@ def generate_audio(
prompt = [params_infer_code.get("prompt", "") + i for i in prompt]
params_infer_code.pop("prompt", "")
result = infer_code(
self.pretrain_models, prompt, **params_infer_code, return_hidden=use_decoder
self.pretrain_models,
prompt,
return_hidden=use_decoder,
**params_infer_code,
)

if use_decoder:
Expand All @@ -326,37 +305,7 @@ def generate_audio(
def sample_random_speaker(
    self,
) -> torch.Tensor:
    """Draw one random speaker embedding from the stored speaker statistics.

    Returns:
        A 1-D tensor of shape ``(dim,)`` on the same device as the
        speaker statistics, sampled as ``randn(dim) * std + mean``.
    """
    gpt_model = self.pretrain_models["gpt"]
    assert gpt_model is not None, "gpt model not loaded"
    # Embedding width is read off the first GPT layer's MLP input size.
    embed_dim = gpt_model.gpt.layers[0].mlp.gate_proj.in_features
    # "spk_stat" packs std and mean back-to-back; split it in half.
    std, mean = self.pretrain_models["spk_stat"].chunk(2)
    noise = torch.randn(embed_dim, device=std.device)
    return noise * std + mean

def init_normalizer(self, lang: str) -> None:
    """Lazily build and cache a text-normalization callable for *lang*.

    Chinese ("zh") uses WeTextProcessing's ``tn`` normalizer; every other
    language uses ``nemo_text_processing``. The resulting callable is cached
    in ``self.normalizer[lang]``; repeated calls for the same language are
    no-ops.

    Args:
        lang: language code, e.g. "zh" or "en".

    Raises:
        ImportError: if the required normalization package is not installed.
            (The original code logged a warning and then crashed with an
            unrelated NameError when it used the never-imported ``Normalizer``;
            failing with the real ImportError is explicit and actionable.)
    """
    if lang in self.normalizer:
        return  # already initialized for this language

    if lang == "zh":
        try:
            from tn.chinese.normalizer import Normalizer
        # Catch only ImportError — a bare `except:` would also swallow
        # SystemExit/KeyboardInterrupt.
        except ImportError:
            self.logger.log(
                logging.WARNING,
                "Package WeTextProcessing not found! "
                "Run: conda install -c conda-forge pynini=2.1.5 "
                "&& pip install WeTextProcessing",
            )
            raise
        self.normalizer[lang] = Normalizer().normalize
    else:
        try:
            from nemo_text_processing.text_normalization.normalize import (
                Normalizer,
            )
        except ImportError:
            self.logger.log(
                logging.WARNING,
                "Package nemo_text_processing not found! "
                "Run: conda install -c conda-forge pynini=2.1.5 "
                "&& pip install nemo_text_processing",
            )
            raise
        self.normalizer[lang] = partial(
            Normalizer(input_case="cased", lang=lang).normalize,
            verbose=False,
            punct_post_process=True,
        )
40 changes: 0 additions & 40 deletions modules/ChatTTS/ChatTTS/experimental/llm.py

This file was deleted.

25 changes: 0 additions & 25 deletions modules/ChatTTS/ChatTTS/utils/gpu_utils.py

This file was deleted.

0 comments on commit e8a1d7b

Please sign in to comment.