Skip to content

Commit

Permalink
✨ ChatTTS 支持
Browse files Browse the repository at this point in the history
reference 推理 #113

- ChatTTS infer 支持 spk 中的 ref_wav 提示
  • Loading branch information
zhzLuke96 committed Jul 29, 2024
1 parent c5b5a27 commit ff4991e
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 8 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,11 @@ WIP 开发中

#### TTS

| 模型名称 | 流式级别 | 支持复刻 | 支持训练 | 支持 prompt | 实现情况 |
| ---------- | -------- | -------- | -------- | ----------- | -------- |
| ChatTTS | token 级 | ||||
| FishSpeech | 句子级 |||| |
| CosyVoice | 句子级 |||||
| 模型名称 | 流式级别 | 支持复刻 | 支持训练 | 支持 prompt | 实现情况 |
| ---------- | -------- | -------- | -------- | ----------- | ---------------------- |
| ChatTTS | token 级 | ||| |
| FishSpeech | 句子级 ||||(SFT 版本开发中 🚧) |
| CosyVoice | 句子级 |||| |

#### ASR

Expand Down
27 changes: 25 additions & 2 deletions modules/core/models/tts/ChatTtsModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import threading
from typing import Any, Generator, Union

import librosa
import numpy as np

from modules.core.models.TTSModel import TTSModel
Expand All @@ -9,6 +11,7 @@
from modules.core.pipeline.pipeline import TTSSegment
from modules.core.pipeline.processor import NP_AUDIO
from modules.core.spk.TTSSpeaker import TTSSpeaker
from modules.utils import audio_utils
from modules.utils.SeedContext import SeedContext


Expand Down Expand Up @@ -65,6 +68,20 @@ def get_infer(self, context: TTSPipelineContext):
def interrupt(self, context: TTSPipelineContext = None) -> None:
self.current_infer.interrupt()

def get_ref_wav(self, segment: TTSSegment, target_sr: int = 24000):
    """Fetch the reference (prompt) audio for a segment's speaker.

    Looks up the speaker reference whose emotion matches the segment's
    emotion, decodes the raw bytes to a float waveform and resamples it
    to the rate the model expects.

    :param segment: segment carrying the speaker (``segment.spk``) and
        the emotion used to select a reference.
    :param target_sr: output sample rate; defaults to 24 kHz, which is
        what ChatTTS consumes.
    :return: ``(wav, text)`` where ``wav`` is a float waveform and
        ``text`` is the reference transcript, or ``(None, None)`` when
        the speaker has no reference matching the emotion.
    """
    spk = segment.spk
    emotion = segment.emotion
    ref_data = spk.get_ref(lambda x: x.emotion == emotion)
    if ref_data is None:
        return None, None
    wav = audio_utils.bytes_to_librosa_array(
        audio_bytes=ref_data.wav, sample_rate=ref_data.wav_sr
    )
    # Resample from the reference's native rate to the model's rate.
    wav = librosa.resample(wav, orig_sr=ref_data.wav_sr, target_sr=target_sr)
    return wav, ref_data.text

def generate_batch_base(
self, segments: list[TTSSegment], context: TTSPipelineContext, stream=False
) -> Union[list[NP_AUDIO], Generator[list[NP_AUDIO], Any, None]]:
Expand All @@ -85,6 +102,8 @@ def _gen():

seg0 = segments[0]
spk_emb = self.get_spk_emb(segment=seg0, context=context) if seg0.spk else None
spk_wav, txt_smp = self.get_ref_wav(seg0)
spk_smp = infer._sample_audio_speaker(spk_wav) if spk_wav is not None else None
top_P = seg0.top_p
top_K = seg0.top_k
temperature = seg0.temperature
Expand All @@ -108,7 +127,9 @@ def _gen():
with SeedContext(seed, cudnn_deterministic=False):
results = infer.generate_audio(
text=texts,
spk_emb=spk_emb,
spk_emb=spk_emb if spk_smp is None else None,
spk_smp=spk_smp,
txt_smp=txt_smp,
top_P=top_P,
top_K=top_K,
temperature=temperature,
Expand All @@ -131,7 +152,9 @@ def _gen() -> Generator[list[NP_AUDIO], None, None]:
with SeedContext(seed, cudnn_deterministic=False):
for results in infer.generate_audio_stream(
text=texts,
spk_emb=spk_emb,
spk_emb=spk_emb if spk_smp is None else None,
spk_smp=spk_smp,
txt_smp=txt_smp,
top_P=top_P,
top_K=top_K,
temperature=temperature,
Expand Down
3 changes: 2 additions & 1 deletion modules/core/models/tts/CosyVoiceModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from modules.core.spk import TTSSpeaker
from modules.devices import devices
from modules.repos_static.cosyvoice.cosyvoice.cli.model import CosyVoiceModel
from modules.utils import audio_utils
from modules.utils.SeedContext import SeedContext

max_val = 0.8
Expand Down Expand Up @@ -204,7 +205,7 @@ def spk_to_ref_wav(self, spk: TTSSpeaker, emotion: str = ""):
ref_data = spk.get_ref(lambda x: x.emotion == emotion)
if ref_data is None:
return None, None
wav = ref_data.wav
wav = audio_utils.bytes_to_librosa_array(ref_data.wav)
# 调整采样率到 16kHz
wav = librosa.resample(wav, orig_sr=ref_data.wav_sr, target_sr=target_sr)
return wav, ref_data.text
Expand Down
25 changes: 25 additions & 0 deletions modules/core/models/zoo/ChatTTSInfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ChatTTSInfer:

def __init__(self, instance: Chat) -> None:
self.instance = instance
self.device = instance.device
self.dtype = instance.dtype
ChatTTSInfer.current_infer = self

if zoo.zoo_config.debug_generate:
Expand All @@ -61,6 +63,21 @@ def interrupt(cls):
cls.current_infer.instance.interrupt()
cls.logger.info("Interrupted current infer")

@torch.inference_mode()
def _sample_audio_speaker(
    self, wav: Union[np.ndarray, torch.Tensor]
) -> torch.Tensor:
    """Encode a reference waveform with the DVAE ("encode" mode).

    Accepts either a numpy array or a tensor; the returned tensor is
    squeezed of its leading dim and placed on the infer instance's
    device/dtype.
    """
    tensor = torch.from_numpy(wav) if isinstance(wav, np.ndarray) else wav
    tensor = tensor.to(device=self.device, dtype=self.dtype)
    # TODO: ideally avoid autocast here, but that requires changing the
    # dvae code.
    with torch.autocast(device_type=self.device.type, dtype=self.dtype):
        encoded = self.instance.dvae(tensor, "encode")
    return encoded.squeeze_(0).to(device=self.device, dtype=self.dtype)

def infer(
self,
text: Union[str, list[str]],
Expand Down Expand Up @@ -267,6 +284,8 @@ def _generate_audio(
self,
text: Union[str, list[str]],
spk_emb: Union[None, torch.Tensor] = None,
spk_smp: Union[None, torch.Tensor] = None,
txt_smp: Union[None, str] = None,
top_P=0.7,
top_K=20,
temperature=0.3,
Expand All @@ -282,6 +301,8 @@ def _generate_audio(
params = Chat.InferCodeParams(
prompt="",
spk_emb=spk_emb,
spk_smp=spk_smp,
txt_smp=txt_smp,
top_P=top_P,
top_K=top_K,
temperature=temperature,
Expand All @@ -305,6 +326,8 @@ def generate_audio(
self,
text: Union[str, list[str]],
spk_emb: Union[None, torch.Tensor] = None,
spk_smp: Union[None, torch.Tensor] = None,
txt_smp: Union[None, str] = None,
top_P=0.7,
top_K=20,
temperature=0.3,
Expand All @@ -319,6 +342,8 @@ def generate_audio(
data = self._generate_audio(
text=text,
spk_emb=spk_emb,
spk_smp=spk_smp,
txt_smp=txt_smp,
top_P=top_P,
top_K=top_K,
temperature=temperature,
Expand Down
2 changes: 2 additions & 0 deletions modules/repos_static/ChatTTS/ChatTTS/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def apply_spk_emb(
@staticmethod
@torch.no_grad()
def _decode_prompt(prompt: str) -> torch.Tensor:
if isinstance(prompt, torch.Tensor):
return prompt
dec = b14.decode_from_string(prompt)
shp = np.frombuffer(dec[:4], dtype="<u2")
p = np.frombuffer(
Expand Down
13 changes: 13 additions & 0 deletions modules/utils/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@
from io import BytesIO

import numpy as np
import numpy.typing as npt
import pyrubberband as pyrb
import soundfile as sf
from pydub import AudioSegment, effects
import scipy.io.wavfile as wavfile
import io

INT16_MAX = np.iinfo(np.int16).max


def bytes_to_librosa_array(audio_bytes: bytes, sample_rate=None) -> npt.NDArray:
    """Convert raw little-endian int16 PCM bytes to a float32 waveform.

    The samples are scaled into ``[-1, 1]`` by dividing by the int16
    maximum, matching what librosa expects. The previous implementation
    round-tripped the samples through an in-memory WAV container
    (``wavfile.write`` followed by ``wavfile.read``), which was an
    identity transform on the data; the samples are now decoded
    directly.

    :param audio_bytes: raw 16-bit PCM sample bytes.
    :param sample_rate: unused (the conversion does not depend on it);
        kept — and made optional — for backward compatibility with
        callers that pass it and callers that do not.
    :return: 1-D ``float32`` array of samples in ``[-1, 1]``.
    """
    audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
    return audio_np.astype(np.float32) / np.iinfo(np.int16).max


def audio_to_int16(audio_data: np.ndarray) -> np.ndarray:
if (
audio_data.dtype == np.float32
Expand Down

0 comments on commit ff4991e

Please sign in to comment.