Skip to content

Commit

Permalink
✨ tts api support stream #5
Browse files Browse the repository at this point in the history
- 简单支持 stream
- 缺少 enhance
- 缺少 adjust
- 缺少 lru_cache
- 缺少 batch generate
  • Loading branch information
zhzLuke96 committed Jun 20, 2024
1 parent ea7399f commit 15e0b2c
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 14 deletions.
2 changes: 1 addition & 1 deletion modules/ChatTTS/ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _infer(
stream_text=False,
):

assert self.check_model(use_decoder=use_decoder)
# assert self.check_model(use_decoder=use_decoder)

if not isinstance(text, list):
text = [text]
Expand Down
20 changes: 19 additions & 1 deletion modules/api/impl/handler/AudioHandler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import io
from typing import Generator

import numpy as np
import soundfile as sf
Expand All @@ -10,7 +11,24 @@

class AudioHandler:
def enqueue(self) -> tuple[np.ndarray, int]:
raise NotImplementedError
raise NotImplementedError("Method 'enqueue' must be implemented by subclass")

def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
raise NotImplementedError(
"Method 'enqueue_stream' must be implemented by subclass"
)

def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]:
for audio_data, sample_rate in self.enqueue_stream():
buffer = io.BytesIO()
sf.write(buffer, audio_data, sample_rate, format="wav")
buffer.seek(0)

if format == AudioFormat.mp3:
buffer = api_utils.wav_to_mp3(buffer)

binary = buffer.read()
yield binary

def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
audio_data, sample_rate = self.enqueue()
Expand Down
35 changes: 35 additions & 0 deletions modules/api/impl/handler/TTSHandler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Generator
import numpy as np

from modules.api.impl.handler.AudioHandler import AudioHandler
Expand All @@ -8,8 +9,13 @@
from modules.normalization import text_normalize
from modules.speaker import Speaker
from modules.synthesize_audio import synthesize_audio
from modules.synthesize_stream import synthesize_stream
from modules.utils.audio import apply_prosody_to_audio_data

import logging

logger = logging.getLogger(__name__)


class TTSHandler(AudioHandler):
def __init__(
Expand Down Expand Up @@ -95,3 +101,32 @@ def enqueue(self) -> tuple[np.ndarray, int]:
)

return audio_data, sample_rate

def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]:
text = text_normalize(self.text_content)
tts_config = self.tts_config
infer_config = self.infer_config
adjust_config = self.adjest_config
enhancer_config = self.enhancer_config

if enhancer_config.enabled:
logger.warning(
"enhancer_config is enabled, but it is not supported in stream mode"
)

gen = synthesize_stream(
text,
spk=self.spk,
temperature=tts_config.temperature,
top_P=tts_config.top_p,
top_K=tts_config.top_k,
prompt1=tts_config.prompt1,
prompt2=tts_config.prompt2,
prefix=tts_config.prefix,
infer_seed=infer_config.seed,
spliter_threshold=infer_config.spliter_threshold,
end_of_sentence=infer_config.eos,
)

for sr, wav in gen:
yield wav, sr
22 changes: 18 additions & 4 deletions modules/api/impl/tts_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker

import logging

logger = logging.getLogger(__name__)


class TTSParams(BaseModel):
text: str = Query(..., description="Text to synthesize")
Expand Down Expand Up @@ -44,6 +48,8 @@ class TTSParams(BaseModel):
pitch: float = Query(0, description="Pitch of the audio")
volume_gain: float = Query(0, description="Volume gain of the audio")

stream: bool = Query(False, description="Stream the audio")


async def synthesize_tts(params: TTSParams = Depends()):
try:
Expand Down Expand Up @@ -132,14 +138,22 @@ async def synthesize_tts(params: TTSParams = Depends()):
adjust_config=adjust_config,
enhancer_config=enhancer_config,
)

buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))

media_type = f"audio/{params.format}"
if params.format == "mp3":
media_type = "audio/mpeg"
return StreamingResponse(buffer, media_type=media_type)

if params.stream:
if infer_config.batch_size != 1:
# 流式生成下仅支持 batch size 为 1,当前请求参数将被忽略
logger.warning(
f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
)

buffer_gen = handler.enqueue_to_stream(format=AudioFormat(params.format))
return StreamingResponse(buffer_gen, media_type=media_type)
else:
buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
return StreamingResponse(buffer, media_type=media_type)
except Exception as e:
import logging

Expand Down
92 changes: 84 additions & 8 deletions modules/generate_audio.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import gc
import logging
from typing import Union
from typing import Generator, Union

import numpy as np
import torch

from modules import config, models
from modules.ChatTTS import ChatTTS
from modules.devices import devices
from modules.speaker import Speaker
from modules.utils.cache import conditional_cache
from modules.utils.SeedContext import SeedContext

logger = logging.getLogger(__name__)

SAMPLE_RATE = 24000


def generate_audio(
text: str,
Expand Down Expand Up @@ -42,20 +45,18 @@ def generate_audio(
return (sample_rate, wav)


@torch.inference_mode()
def generate_audio_batch(
def parse_infer_params(
texts: list[str],
chat_tts: ChatTTS.Chat,
temperature: float = 0.3,
top_P: float = 0.7,
top_K: float = 20,
spk: Union[int, Speaker] = -1,
infer_seed: int = -1,
use_decoder: bool = True,
prompt1: str = "",
prompt2: str = "",
prefix: str = "",
):
chat_tts = models.load_chat_tts()
params_infer_code = {
"spk_emb": None,
"temperature": temperature,
Expand Down Expand Up @@ -97,18 +98,93 @@ def generate_audio_batch(
}
)

return params_infer_code


@torch.inference_mode()
def generate_audio_batch(
    texts: list[str],
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
) -> list[tuple[int, np.ndarray]]:
    """Synthesize a batch of texts with ChatTTS.

    Args:
        texts: Texts to synthesize, one waveform is produced per entry.
        temperature / top_P / top_K: Sampling parameters for inference.
        spk: Speaker id or Speaker object; -1 means the model default.
        infer_seed: Seed for deterministic generation; -1 means random.
        use_decoder: Forwarded to the ChatTTS generation call.
        prompt1 / prompt2 / prefix: Prompt strings forwarded to inference.

    Returns:
        A list of (SAMPLE_RATE, waveform) tuples, one per input text, with
        each waveform flattened to a float32 numpy array.
    """
    chat_tts = models.load_chat_tts()
    params_infer_code = parse_infer_params(
        texts=texts,
        chat_tts=chat_tts,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )

    with SeedContext(infer_seed, True):
        # NOTE(review): the diff left both the old positional and the new
        # keyword-argument call line in place (a syntax error); only the
        # keyword form — the commit's intended new code — is kept.
        wavs = chat_tts.generate_audio(
            texts=texts, params_infer_code=params_infer_code, use_decoder=use_decoder
        )

    if config.auto_gc:
        devices.torch_gc()
        gc.collect()

    return [(SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)) for wav in wavs]


# TODO: generate_audio_stream should also support lru cache
@torch.inference_mode()
def generate_audio_stream(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
) -> Generator[tuple[int, np.ndarray], None, None]:
    """Stream audio chunks for a single text as they are generated.

    Args:
        text: Single text to synthesize (wrapped into a batch of one).
        temperature / top_P / top_K: Sampling parameters for inference.
        spk: Speaker id or Speaker object; -1 means the model default.
        infer_seed: Seed for deterministic generation; -1 means random.
        use_decoder: Forwarded to the ChatTTS generation call.
        prompt1 / prompt2 / prefix: Prompt strings forwarded to inference.

    Yields:
        (SAMPLE_RATE, waveform) tuples, one per streamed chunk, with each
        waveform flattened to a float32 numpy array.
    """
    chat_tts = models.load_chat_tts()
    texts = [text]
    params_infer_code = parse_infer_params(
        texts=texts,
        chat_tts=chat_tts,
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )

    with SeedContext(infer_seed, True):
        wavs_gen = chat_tts.generate_audio(
            prompt=texts,
            params_infer_code=params_infer_code,
            use_decoder=use_decoder,
            stream=True,
        )

        for wav in wavs_gen:
            # Fix: yield a tuple (not a list) so the values match the declared
            # Generator[tuple[int, np.ndarray], None, None] return type and the
            # non-streaming generate_audio_batch output shape.
            yield SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)

    # NOTE(review): removed the unused `sample_rate = 24000` local (shadowed by
    # the module-level SAMPLE_RATE constant) and a diff-residue `return [...]`
    # line that referenced an undefined `wavs` variable.
    if config.auto_gc:
        devices.torch_gc()
        gc.collect()


lru_cache_enabled = False
Expand Down
42 changes: 42 additions & 0 deletions modules/synthesize_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import io
from typing import Generator, Union

import numpy as np

from modules import generate_audio as generate
from modules.SentenceSplitter import SentenceSplitter
from modules.speaker import Speaker


def synthesize_stream(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
    spliter_threshold: int = 100,
    end_of_sentence="",
) -> Generator[tuple[int, np.ndarray], None, None]:
    """Split *text* into sentences and stream synthesized audio for each.

    The text is segmented with SentenceSplitter (segments bounded by
    *spliter_threshold*); each sentence — with *end_of_sentence* appended —
    is synthesized in turn, and every generated chunk is yielded as a
    (sample_rate, waveform) pair.
    """
    splitter = SentenceSplitter(spliter_threshold)

    for sentence in splitter.parse(text):
        chunk_gen = generate.generate_audio_stream(
            text=sentence + end_of_sentence,
            temperature=temperature,
            top_P=top_P,
            top_K=top_K,
            spk=spk,
            infer_seed=infer_seed,
            use_decoder=use_decoder,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
        )
        # Re-pack each chunk as a tuple, exactly as upstream consumers expect.
        for sample_rate, waveform in chunk_gen:
            yield sample_rate, waveform

0 comments on commit 15e0b2c

Please sign in to comment.