diff --git a/data/load_json_spk.py b/data/load_json_spk.py index 96ab248..6da6af9 100644 --- a/data/load_json_spk.py +++ b/data/load_json_spk.py @@ -1,6 +1,6 @@ import json -from modules.speaker import speaker_mgr +from modules.core.speaker import speaker_mgr # 出处: https://github.com/2noise/ChatTTS/issues/238 data = json.load(open("./data/slct_voice_240605.json", "r")) diff --git a/modules/ChatTTS/ChatTTS/model/gpt.py b/modules/ChatTTS/ChatTTS/model/gpt.py index 0dd9a04..611bf9c 100644 --- a/modules/ChatTTS/ChatTTS/model/gpt.py +++ b/modules/ChatTTS/ChatTTS/model/gpt.py @@ -545,23 +545,35 @@ def generate( if show_tqdm: pbar.close() self.logger.warn("regenerate in order to ensure non-empty") + del_all(attentions) + del_all(hiddens) + del ( + start_idx, + end_idx, + finish, + temperature, + attention_mask_cache, + past_key_values, + idx_next, + inputs_ids_tmp, + ) new_gen = self.generate( - emb, - inputs_ids, - old_temperature, - eos_token, - attention_mask, - max_new_token, - min_new_token, - logits_warpers, - logits_processors, - infer_text, - return_attn, - return_hidden, - stream, - show_tqdm, - ensure_non_empty, - context, + emb=emb, + inputs_ids=inputs_ids, + old_temperature=old_temperature, + eos_token=eos_token, + attention_mask=attention_mask, + max_new_token=max_new_token, + min_new_token=min_new_token, + logits_warpers=logits_warpers, + logits_processors=logits_processors, + infer_text=infer_text, + return_attn=return_attn, + return_hidden=return_hidden, + stream=stream, + show_tqdm=show_tqdm, + ensure_non_empty=ensure_non_empty, + context=context, ) for result in new_gen: yield result diff --git a/modules/api/impl/google_api.py b/modules/api/impl/google_api.py index e5c7172..8d76262 100644 --- a/modules/api/impl/google_api.py +++ b/modules/api/impl/google_api.py @@ -5,12 +5,12 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager -from modules.api.impl.handler.SSMLHandler import SSMLHandler -from modules.api.impl.handler.TTSHandler import TTSHandler -from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig -from modules.speaker import Speaker, speaker_mgr +from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.SSMLHandler import SSMLHandler +from modules.core.handler.TTSHandler import TTSHandler +from modules.core.speaker import Speaker, speaker_mgr class SynthesisInput(BaseModel): diff --git a/modules/api/impl/handler/AudioHandler.py b/modules/api/impl/handler/AudioHandler.py deleted file mode 100644 index a90a9c3..0000000 --- a/modules/api/impl/handler/AudioHandler.py +++ /dev/null @@ -1,139 +0,0 @@ -import base64 -import io -import wave -from typing import AsyncGenerator, Generator - -import numpy as np -from fastapi import Request -from pydub import AudioSegment - -from modules.api.impl.model.audio_model import AudioFormat -from modules.ChatTTSInfer import ChatTTSInfer -from modules.utils.audio import ndarray_to_segment - - -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000): - wav_buf = io.BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - 
vfout.writeframes(frame_input) - wav_buf.seek(0) - return wav_buf.read() - - -wav_header = wave_header_chunk() - - -def read_to_wav(audio_data: np.ndarray, buffer: io.BytesIO): - audio_data = audio_data / np.max(np.abs(audio_data)) - chunk = (audio_data * 32768).astype(np.int16) - buffer.write(chunk.tobytes()) - return buffer - - -def align_audio(audio_data: np.ndarray, channels=1) -> np.ndarray: - samples_per_frame = channels - total_samples = len(audio_data) - aligned_samples = total_samples - (total_samples % samples_per_frame) - return audio_data[:aligned_samples] - - -def pad_audio_frame(audio_data: np.ndarray, frame_size=1152, channels=1) -> np.ndarray: - samples_per_frame = frame_size * channels - padding_needed = ( - samples_per_frame - len(audio_data) % samples_per_frame - ) % samples_per_frame - return np.pad(audio_data, (0, padding_needed), mode="constant") - - -class AudioHandler: - def enqueue(self) -> tuple[np.ndarray, int]: - raise NotImplementedError("Method 'enqueue' must be implemented by subclass") - - def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]: - raise NotImplementedError( - "Method 'enqueue_stream' must be implemented by subclass" - ) - - def encode_audio( - self, audio_data: np.ndarray, sample_rate: int, format: AudioFormat - ) -> io.BytesIO: - buffer = io.BytesIO() - - audio_data = audio_data / np.max(np.abs(audio_data)) - audio_data = (audio_data * 32767).astype(np.int16) - - audio_segment: AudioSegment = ndarray_to_segment( - audio_data, frame_rate=sample_rate - ) - - if format == AudioFormat.mp3: - audio_segment.export(buffer, format="mp3") - buffer.seek(0) - elif format == AudioFormat.wav: - audio_segment.export(buffer, format="wav") - buffer.seek(len(wav_header)) - elif format == AudioFormat.ogg: - # FIXME: 流式输出有 bug,会莫名其妙中断输出... 
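The deleted `encode_audio` above peak-normalizes the float waveform, quantizes it to 16-bit PCM, and exports through pydub. A condensed standalone sketch of that step (illustrative only, not part of the patch; the direct `AudioSegment` constructor is assumed to be equivalent to the repo's `ndarray_to_segment` helper):

```python
import io

import numpy as np
from pydub import AudioSegment


def encode_float_audio(audio_data: np.ndarray, sample_rate: int, fmt: str = "mp3") -> io.BytesIO:
    # Peak-normalize, then quantize to 16-bit PCM, as encode_audio does above.
    audio_data = audio_data / np.max(np.abs(audio_data))
    pcm16 = (audio_data * 32767).astype(np.int16)

    # Wrap the raw samples in a mono, 16-bit pydub segment and export;
    # pydub shells out to ffmpeg for formats such as mp3/ogg.
    segment = AudioSegment(
        pcm16.tobytes(), frame_rate=sample_rate, sample_width=2, channels=1
    )
    buffer = io.BytesIO()
    segment.export(buffer, format=fmt)
    buffer.seek(0)
    return buffer
```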
- audio_segment.export(buffer, format="ogg") - buffer.seek(0) - else: - raise ValueError(f"Invalid audio format: {format}") - - return buffer - - def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]: - if format == AudioFormat.wav: - yield wav_header - - for audio_data, sample_rate in self.enqueue_stream(): - yield self.encode_audio(audio_data, sample_rate, format).read() - - # print("AudioHandler: enqueue_to_stream done") - - async def enqueue_to_stream_with_request( - self, request: Request, format: AudioFormat - ) -> AsyncGenerator[bytes, None]: - buffer_gen = self.enqueue_to_stream(format=AudioFormat(format)) - for chunk in buffer_gen: - disconnected = await request.is_disconnected() - if disconnected: - ChatTTSInfer.interrupt() - break - - yield chunk - - # just for test - def enqueue_to_stream_join( - self, format: AudioFormat - ) -> Generator[bytes, None, None]: - if format == AudioFormat.wav: - yield wav_header - - data = None - for audio_data, sample_rate in self.enqueue_stream(): - data = audio_data if data is None else np.concatenate((data, audio_data)) - buffer = self.encode_audio(data, sample_rate, format) - yield buffer.read() - - def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO: - audio_data, sample_rate = self.enqueue() - buffer = self.encode_audio(audio_data, sample_rate, format) - if format == AudioFormat.wav: - buffer = io.BytesIO(wav_header + buffer.read()) - return buffer - - def enqueue_to_bytes(self, format: AudioFormat) -> bytes: - buffer = self.enqueue_to_buffer(format=format) - binary = buffer.read() - return binary - - def enqueue_to_base64(self, format: AudioFormat) -> str: - binary = self.enqueue_to_bytes(format=format) - - base64_encoded = base64.b64encode(binary) - base64_string = base64_encoded.decode("utf-8") - - return base64_string diff --git a/modules/api/impl/handler/SSMLHandler.py b/modules/api/impl/handler/SSMLHandler.py deleted file mode 100644 index 394643a..0000000 --- a/modules/api/impl/handler/SSMLHandler.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Generator - -import numpy as np -from fastapi import HTTPException - -from modules.api.impl.handler.AudioHandler import AudioHandler -from modules.api.impl.model.audio_model import AdjustConfig -from modules.api.impl.model.chattts_model import InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig -from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full -from modules.normalization import text_normalize -from modules.ssml_parser.SSMLParser import create_ssml_parser -from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments -from modules.utils import audio - - -class SSMLHandler(AudioHandler): - def __init__( - self, - ssml_content: str, - infer_config: InferConfig, - adjust_config: AdjustConfig, - enhancer_config: EnhancerConfig, - ) -> None: - assert isinstance(ssml_content, str), "ssml_content must be a string." - assert isinstance( - infer_config, InferConfig - ), "infer_config must be an InferConfig object." - assert isinstance( - adjust_config, AdjustConfig - ), "adjest_config should be AdjustConfig" - assert isinstance( - enhancer_config, EnhancerConfig - ), "enhancer_config must be an EnhancerConfig object." 
- - self.ssml_content = ssml_content - self.infer_config = infer_config - self.adjest_config = adjust_config - self.enhancer_config = enhancer_config - - self.validate() - - def validate(self): - # TODO params checker - pass - - def enqueue(self) -> tuple[np.ndarray, int]: - ssml_content = self.ssml_content - infer_config = self.infer_config - adjust_config = self.adjest_config - enhancer_config = self.enhancer_config - - parser = create_ssml_parser() - segments = parser.parse(ssml_content) - for seg in segments: - seg["text"] = text_normalize(seg["text"], is_end=True) - - if len(segments) == 0: - raise HTTPException( - status_code=422, detail="The SSML text is empty or parsing failed." - ) - - synthesize = SynthesizeSegments( - batch_size=infer_config.batch_size, - eos=infer_config.eos, - spliter_thr=infer_config.spliter_threshold, - ) - audio_segments = synthesize.synthesize_segments(segments) - combined_audio = combine_audio_segments(audio_segments) - - sample_rate, audio_data = audio.pydub_to_np(combined_audio) - - if enhancer_config.enabled: - nfe = enhancer_config.nfe - solver = enhancer_config.solver - lambd = enhancer_config.lambd - tau = enhancer_config.tau - - audio_data, sample_rate = apply_audio_enhance_full( - audio_data=audio_data, - sr=sample_rate, - nfe=nfe, - solver=solver, - lambd=lambd, - tau=tau, - ) - - audio_data = audio.apply_prosody_to_audio_data( - audio_data=audio_data, - rate=adjust_config.speed_rate, - pitch=adjust_config.pitch, - volume=adjust_config.volume_gain_db, - sr=sample_rate, - ) - - if adjust_config.normalize: - sample_rate, audio_data = audio.apply_normalize( - audio_data=audio_data, headroom=adjust_config.headroom, sr=sample_rate - ) - - return audio_data, sample_rate - - def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]: - # TODO: 应该很好支持stream... 
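Both the deleted handlers here and the new pipeline apply the same post-processing chain after synthesis: prosody adjustment (speed/pitch/volume) followed by optional loudness normalization. A minimal sketch of that chain, with signatures taken from the calls above (the new core code imports the same helpers via `modules.utils.audio_utils` instead):

```python
import numpy as np

from modules.utils.audio import apply_normalize, apply_prosody_to_audio_data


def post_process(audio_data: np.ndarray, sample_rate: int, adjust_config) -> tuple[np.ndarray, int]:
    # Speed / pitch / volume adjustment, exactly as the handlers call it.
    audio_data = apply_prosody_to_audio_data(
        audio_data=audio_data,
        rate=adjust_config.speed_rate,
        pitch=adjust_config.pitch,
        volume=adjust_config.volume_gain_db,
        sr=sample_rate,
    )

    # Optional loudness normalization with the configured headroom.
    if adjust_config.normalize:
        sample_rate, audio_data = apply_normalize(
            audio_data=audio_data,
            headroom=adjust_config.headroom,
            sr=sample_rate,
        )

    return audio_data, sample_rate
```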
- raise NotImplementedError("Stream is not supported for SSMLHandler.") diff --git a/modules/api/impl/handler/TTSHandler.py b/modules/api/impl/handler/TTSHandler.py deleted file mode 100644 index d1b3c24..0000000 --- a/modules/api/impl/handler/TTSHandler.py +++ /dev/null @@ -1,156 +0,0 @@ -import logging -from typing import Generator - -import numpy as np - -from modules.api.impl.handler.AudioHandler import AudioHandler -from modules.api.impl.model.audio_model import AdjustConfig -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig -from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full -from modules.normalization import text_normalize -from modules.speaker import Speaker -from modules.synthesize_audio import synthesize_audio -from modules.synthesize_stream import synthesize_stream -from modules.utils.audio import apply_normalize, apply_prosody_to_audio_data - -logger = logging.getLogger(__name__) - - -class TTSHandler(AudioHandler): - def __init__( - self, - text_content: str, - spk: Speaker, - tts_config: ChatTTSConfig, - infer_config: InferConfig, - adjust_config: AdjustConfig, - enhancer_config: EnhancerConfig, - ): - assert isinstance(text_content, str), "text_content should be str" - assert isinstance(spk, Speaker), "spk should be Speaker" - assert isinstance( - tts_config, ChatTTSConfig - ), "tts_config should be ChatTTSConfig" - assert isinstance( - infer_config, InferConfig - ), "infer_config should be InferConfig" - assert isinstance( - adjust_config, AdjustConfig - ), "adjest_config should be AdjustConfig" - assert isinstance( - enhancer_config, EnhancerConfig - ), "enhancer_config should be EnhancerConfig" - - self.text_content = text_content - self.spk = spk - self.tts_config = tts_config - self.infer_config = infer_config - self.adjest_config = adjust_config - self.enhancer_config = enhancer_config - - self.validate() - - def validate(self): - # TODO params checker - pass - - def enqueue(self) -> tuple[np.ndarray, int]: - text = text_normalize(self.text_content) - tts_config = self.tts_config - infer_config = self.infer_config - adjust_config = self.adjest_config - enhancer_config = self.enhancer_config - - sample_rate, audio_data = synthesize_audio( - text, - spk=self.spk, - temperature=tts_config.temperature, - top_P=tts_config.top_p, - top_K=tts_config.top_k, - prompt1=tts_config.prompt1, - prompt2=tts_config.prompt2, - prefix=tts_config.prefix, - infer_seed=infer_config.seed, - batch_size=infer_config.batch_size, - spliter_threshold=infer_config.spliter_threshold, - end_of_sentence=infer_config.eos, - ) - - if enhancer_config.enabled: - nfe = enhancer_config.nfe - solver = enhancer_config.solver - lambd = enhancer_config.lambd - tau = enhancer_config.tau - - audio_data, sample_rate = apply_audio_enhance_full( - audio_data=audio_data, - sr=sample_rate, - nfe=nfe, - solver=solver, - lambd=lambd, - tau=tau, - ) - - audio_data = apply_prosody_to_audio_data( - audio_data=audio_data, - rate=adjust_config.speed_rate, - pitch=adjust_config.pitch, - volume=adjust_config.volume_gain_db, - sr=sample_rate, - ) - - if adjust_config.normalize: - sample_rate, audio_data = apply_normalize( - audio_data=audio_data, - headroom=adjust_config.headroom, - sr=sample_rate, - ) - - return audio_data, sample_rate - - def enqueue_stream(self) -> Generator[tuple[np.ndarray, int], None, None]: - text = text_normalize(self.text_content) - tts_config = self.tts_config - infer_config = 
self.infer_config - adjust_config = self.adjest_config - enhancer_config = self.enhancer_config - - if enhancer_config.enabled: - logger.warning( - "enhancer_config is enabled, but it is not supported in stream mode" - ) - - gen = synthesize_stream( - text, - spk=self.spk, - temperature=tts_config.temperature, - top_P=tts_config.top_p, - top_K=tts_config.top_k, - prompt1=tts_config.prompt1, - prompt2=tts_config.prompt2, - prefix=tts_config.prefix, - infer_seed=infer_config.seed, - spliter_threshold=infer_config.spliter_threshold, - end_of_sentence=infer_config.eos, - ) - - # FIXME: 很奇怪,合并出来的音频每个 chunk 之前会有一段异常,暂时没有查出来是哪里的问题,可能是解码时候切割漏了?或者多了? - for sr, wav in gen: - - wav = apply_prosody_to_audio_data( - audio_data=wav, - rate=adjust_config.speed_rate, - pitch=adjust_config.pitch, - volume=adjust_config.volume_gain_db, - sr=sr, - ) - - if adjust_config.normalize: - sr, wav = apply_normalize( - audio_data=wav, - headroom=adjust_config.headroom, - sr=sr, - ) - - yield wav, sr diff --git a/modules/api/impl/models_api.py b/modules/api/impl/models_api.py index 7d03ad1..5300740 100644 --- a/modules/api/impl/models_api.py +++ b/modules/api/impl/models_api.py @@ -1,18 +1,22 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer -from modules.models import reload_chat_tts, unload_chat_tts +from modules.core.models import zoo def setup(app: APIManager): + """ + TODO: 需要增强 zoo 以支持多 models 管理 + """ + @app.get("/v1/models/reload", response_model=api_utils.BaseResponse) async def reload_models(): - reload_chat_tts() + zoo.ChatTTS.reload_chat_tts() reload_enhancer() return api_utils.success_response("Models reloaded") @app.get("/v1/models/unload", response_model=api_utils.BaseResponse) async def reload_models(): - unload_chat_tts() + zoo.ChatTTS.unload_chat_tts() unload_enhancer() return api_utils.success_response("Models unloaded") diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py index bc6cec3..a98fa17 100644 --- a/modules/api/impl/openai_api.py +++ b/modules/api/impl/openai_api.py @@ -7,12 +7,12 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager -from modules.api.impl.handler.TTSHandler import TTSHandler -from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig +from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.TTSHandler import TTSHandler +from modules.core.speaker import Speaker, speaker_mgr from modules.data import styles_mgr -from modules.speaker import Speaker, speaker_mgr class AudioSpeechRequest(BaseModel): @@ -92,6 +92,7 @@ async def openai_speech_api( spliter_threshold=spliter_threshold, eos=eos, seed=seed, + stream=stream, ) adjust_config = AdjustConfig(speaking_rate=speed) enhancer_config = EnhancerConfig( diff --git a/modules/api/impl/refiner_api.py b/modules/api/impl/refiner_api.py index b205db8..be64bd0 100644 --- a/modules/api/impl/refiner_api.py +++ b/modules/api/impl/refiner_api.py @@ -4,7 +4,7 @@ from modules import refiner from modules.api import utils as api_utils from modules.api.Api import APIManager -from modules.normalization import text_normalize 
+from modules.core.tn.ChatTtsTN import ChatTtsTN class RefineTextRequest(BaseModel): @@ -16,8 +16,8 @@ class RefineTextRequest(BaseModel): temperature: float = 0.7 repetition_penalty: float = 1.0 max_new_token: int = 384 - normalize: bool = True spliter_threshold: int = 300 + normalize: bool = True async def refiner_prompt_post(request: RefineTextRequest): @@ -28,8 +28,8 @@ async def refiner_prompt_post(request: RefineTextRequest): try: text = request.text if request.normalize: - text = text_normalize(request.text) - # TODO 其实这里可以做 spliter 和 batch 处理 + text = ChatTtsTN.normalize(request.text) + # TODO 需要迁移使用 refiner model refined_text = refiner.refine_text( text=text, prompt=request.prompt, diff --git a/modules/api/impl/speaker_api.py b/modules/api/impl/speaker_api.py index 96e7dc4..a3041ef 100644 --- a/modules/api/impl/speaker_api.py +++ b/modules/api/impl/speaker_api.py @@ -4,7 +4,7 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager -from modules.speaker import speaker_mgr +from modules.core.speaker import speaker_mgr class CreateSpeaker(BaseModel): diff --git a/modules/api/impl/ssml_api.py b/modules/api/impl/ssml_api.py index eaabf10..f5c6711 100644 --- a/modules/api/impl/ssml_api.py +++ b/modules/api/impl/ssml_api.py @@ -3,10 +3,10 @@ from pydantic import BaseModel from modules.api.Api import APIManager -from modules.api.impl.handler.SSMLHandler import SSMLHandler -from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat -from modules.api.impl.model.chattts_model import InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig +from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat +from modules.core.handler.datacls.chattts_model import InferConfig +from modules.core.handler.SSMLHandler import SSMLHandler +from modules.core.handler.datacls.enhancer_model import EnhancerConfig class SSMLRequest(BaseModel): @@ -24,6 +24,8 @@ class SSMLRequest(BaseModel): enhancer: EnhancerConfig = EnhancerConfig() adjuster: AdjustConfig = AdjustConfig() + stream: bool = False + async def synthesize_ssml_api( request: SSMLRequest = Body( @@ -35,6 +37,7 @@ async def synthesize_ssml_api( format = request.format.lower() batch_size = request.batch_size eos = request.eos + stream = request.stream spliter_thr = request.spliter_thr enhancer = request.enhancer adjuster = request.adjuster @@ -58,9 +61,7 @@ async def synthesize_ssml_api( ) infer_config = InferConfig( - batch_size=batch_size, - spliter_threshold=spliter_thr, - eos=eos, + batch_size=batch_size, spliter_threshold=spliter_thr, eos=eos, stream=stream ) adjust_config = adjuster enhancer_config = enhancer @@ -72,12 +73,18 @@ async def synthesize_ssml_api( enhancer_config=enhancer_config, ) - buffer = handler.enqueue_to_buffer(format=request.format) - - mime_type = f"audio/{format}" + media_type = f"audio/{format}" if format == AudioFormat.mp3: - mime_type = "audio/mpeg" - return StreamingResponse(buffer, media_type=mime_type) + media_type = "audio/mpeg" + + if stream: + gen = handler.enqueue_to_stream_with_request( + request=request, format=AudioFormat(format) + ) + return StreamingResponse(gen, media_type=media_type) + else: + buffer = handler.enqueue_to_buffer(format=AudioFormat(format)) + return StreamingResponse(buffer, media_type=media_type) except Exception as e: import logging diff --git a/modules/api/impl/tts_api.py b/modules/api/impl/tts_api.py index b20f970..611b825 100644 --- a/modules/api/impl/tts_api.py +++ b/modules/api/impl/tts_api.py @@ -1,5 +1,6 
@@ import io import logging +from typing import Literal, Union from fastapi import Depends, HTTPException, Query, Request from fastapi.responses import FileResponse, StreamingResponse @@ -7,11 +8,11 @@ from modules.api import utils as api_utils from modules.api.Api import APIManager -from modules.api.impl.handler.TTSHandler import TTSHandler -from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig -from modules.speaker import Speaker +from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.TTSHandler import TTSHandler +from modules.core.speaker import Speaker logger = logging.getLogger(__name__) @@ -51,6 +52,10 @@ class TTSParams(BaseModel): stream: bool = Query(False, description="Stream the audio") + no_cache: Union[bool, Literal["on", "off"]] = Query( + False, description="Disable cache" + ) + async def synthesize_tts(request: Request, params: TTSParams = Depends()): try: @@ -102,6 +107,12 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()): prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1) prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2) eos = params.eos or "" + stream = params.stream + no_cache = ( + params.no_cache + if isinstance(params.no_cache, bool) + else params.no_cache == "on" + ) batch_size = int(params.bs) threshold = int(params.thr) @@ -120,6 +131,8 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()): spliter_threshold=threshold, eos=eos, seed=seed, + stream=stream, + no_cache=no_cache, ) adjust_config = AdjustConfig( pitch=params.pitch, @@ -143,13 +156,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()): if params.format == "mp3": media_type = "audio/mpeg" - if params.stream: - if infer_config.batch_size != 1: - # 流式生成下仅支持 batch size 为 1,当前请求参数将被忽略 - logger.warning( - f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1" - ) - + if stream: gen = handler.enqueue_to_stream_with_request( request=request, format=AudioFormat(params.format) ) diff --git a/modules/api/impl/xtts_v2_api.py b/modules/api/impl/xtts_v2_api.py index bae9b45..191c952 100644 --- a/modules/api/impl/xtts_v2_api.py +++ b/modules/api/impl/xtts_v2_api.py @@ -5,11 +5,11 @@ from pydantic import BaseModel from modules.api.Api import APIManager -from modules.api.impl.handler.TTSHandler import TTSHandler -from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig -from modules.speaker import speaker_mgr +from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.TTSHandler import TTSHandler +from modules.core.speaker import speaker_mgr logger = logging.getLogger(__name__) @@ -111,6 +111,7 @@ async def tts_to_audio(request: SynthesisRequest): spliter_threshold=XTTSV2.spliter_threshold, eos=XTTSV2.eos, seed=XTTSV2.infer_seed, + stream=False, ) 
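The endpoints above now thread `stream` (and, for `tts_api`, `no_cache`) into `InferConfig`; the full field set appears where `chattts_model.py` is relocated later in this diff. A minimal sketch of building such a config (the numeric values are placeholders):

```python
from modules.core.handler.datacls.chattts_model import InferConfig

# Request-level inference settings as the routes above now build them.
infer_config = InferConfig(
    batch_size=4,
    spliter_threshold=100,
    eos="[uv_break]",
    seed=42,
    stream=True,      # chunked generation, returned via StreamingResponse
    no_cache=False,   # bypass the InferCache introduced later in this diff
)
```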
adjust_config = AdjustConfig( speed_rate=XTTSV2.speed, @@ -164,6 +165,7 @@ async def tts_stream( spliter_threshold=XTTSV2.spliter_threshold, eos=XTTSV2.eos, seed=XTTSV2.infer_seed, + stream=True, ) adjust_config = AdjustConfig( speed_rate=XTTSV2.speed, diff --git a/modules/api/utils.py b/modules/api/utils.py index ba06e7e..b06a388 100644 --- a/modules/api/utils.py +++ b/modules/api/utils.py @@ -1,11 +1,11 @@ -from typing import Any, Union +from typing import Any, Dict, Union +import numpy as np from pydantic import BaseModel from pydub import AudioSegment +from modules.core.speaker import speaker_mgr from modules.data import styles_mgr -from modules.speaker import speaker_mgr -from modules.ssml import merge_prompt class ParamsTypeError(Exception): @@ -36,6 +36,27 @@ def to_number(value, t, default=0): return default +def merge_prompt(attrs: dict, elem: Dict[str, Any]): + + def attr_num(attrs: Dict[str, Any], k: str, min_value: int, max_value: int): + val = elem.get(k, attrs.get(k, "")) + if val == "": + return + if val == "max": + val = max_value + if val == "min": + val = min_value + val = np.clip(int(val), min_value, max_value) + if "prefix" not in attrs or attrs["prefix"] == None: + attrs["prefix"] = "" + attrs["prefix"] += " " + f"[{k}_{val}]" + + attr_num(attrs, "oral", 0, 9) + attr_num(attrs, "speed", 0, 9) + attr_num(attrs, "laugh", 0, 2) + attr_num(attrs, "break", 0, 7) + + def calc_spk_style(spk: Union[str, int], style: Union[str, int]): voice_attrs = { "spk": None, diff --git a/modules/ssml_parser/__init__.py b/modules/core/__init__.py similarity index 100% rename from modules/ssml_parser/__init__.py rename to modules/core/__init__.py diff --git a/modules/core/handler/AudioHandler.py b/modules/core/handler/AudioHandler.py new file mode 100644 index 0000000..953cebb --- /dev/null +++ b/modules/core/handler/AudioHandler.py @@ -0,0 +1,143 @@ +import base64 +import io +import wave +from typing import AsyncGenerator, Generator + +import numpy as np +from fastapi import Request + +from modules.core.handler.encoder.StreamEncoder import StreamEncoder +from modules.core.handler.encoder.WavFile import WAVFileBytes +from modules.core.handler.datacls.audio_model import AudioFormat +from modules.core.handler.encoder.encoders import ( + AacEncoder, + FlacEncoder, + Mp3Encoder, + OggEncoder, + WavEncoder, +) +from modules.core.models.zoo.ChatTTSInfer import ChatTTSInfer +from modules.core.pipeline.processor import NP_AUDIO + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000): + wav_buf = io.BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + +# NOTE: 这个可能只适合 chattts +wav_header = wave_header_chunk() + + +def remove_wav_bytes_header(wav_bytes: bytes): + wav_file = WAVFileBytes(wav_bytes=wav_bytes) + wav_file.read() + return wav_file.get_body_data() + + +def read_np_to_wav(audio_data: np.ndarray) -> bytes: + audio_data: np.ndarray = audio_data / np.max(np.abs(audio_data)) + audio_data = (audio_data * 32767).astype(np.int16) + return audio_data.tobytes() + + +class AudioHandler: + def enqueue(self) -> NP_AUDIO: + raise NotImplementedError("Method 'enqueue' must be implemented by subclass") + + def enqueue_stream(self) -> Generator[NP_AUDIO, None, None]: + raise NotImplementedError( + "Method 'enqueue_stream' must be implemented by subclass" + ) + + def get_encoder(self, format: 
AudioFormat) -> StreamEncoder: + # TODO 这里可以增加 编码器配置 + if format == AudioFormat.wav: + encoder = WavEncoder() + elif format == AudioFormat.mp3: + encoder = Mp3Encoder() + elif format == AudioFormat.flac: + encoder = FlacEncoder() + # OGG 和 ACC 编码有问题,不知道为啥 + # FIXME: BrokenPipeError: [Errno 32] Broken pipe + elif format == AudioFormat.acc: + encoder = AacEncoder() + # FIXME: BrokenPipeError: [Errno 32] Broken pipe + elif format == AudioFormat.ogg: + encoder = OggEncoder() + else: + raise ValueError(f"Unsupported audio format: {format}") + encoder.open() + encoder.write(wav_header) + + return encoder + + def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]: + encoder = self.get_encoder(format) + chunk_data = bytes() + # NOTE sample_rate 写在文件头里了所以用不到 + for sample_rate, audio_data in self.enqueue_stream(): + audio_bytes = read_np_to_wav(audio_data=audio_data) + encoder.write(audio_bytes) + chunk_data = encoder.read() + while len(chunk_data) > 0: + yield chunk_data + chunk_data = encoder.read() + + encoder.close() + while len(chunk_data) > 0: + yield chunk_data + chunk_data = encoder.read() + + async def enqueue_to_stream_with_request( + self, request: Request, format: AudioFormat + ) -> AsyncGenerator[bytes, None]: + for chunk in self.enqueue_to_stream(format=AudioFormat(format)): + disconnected = await request.is_disconnected() + if disconnected: + # TODO: 这个逻辑应该传递给 zoo + ChatTTSInfer.interrupt() + break + yield chunk + + # just for test + def enqueue_to_stream_join( + self, format: AudioFormat + ) -> Generator[bytes, None, None]: + encoder = self.get_encoder(format) + chunk_data = bytes() + for sample_rate, audio_data in self.enqueue_stream(): + audio_bytes = read_np_to_wav(audio_data=audio_data) + encoder.write(audio_bytes) + chunk_data = encoder.read() + + encoder.close() + while len(chunk_data) > 0: + yield chunk_data + chunk_data = encoder.read() + + def enqueue_to_bytes(self, format: AudioFormat) -> bytes: + encoder = self.get_encoder(format) + sample_rate, audio_data = self.enqueue() + audio_bytes = read_np_to_wav(audio_data=audio_data) + encoder.write(audio_bytes) + encoder.close() + return encoder.read_all() + + def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO: + audio_bytes = self.enqueue_to_bytes(format=format) + return io.BytesIO(audio_bytes) + + def enqueue_to_base64(self, format: AudioFormat) -> str: + binary = self.enqueue_to_bytes(format=format) + + base64_encoded = base64.b64encode(binary) + base64_string = base64_encoded.decode("utf-8") + + return base64_string diff --git a/modules/core/handler/SSMLHandler.py b/modules/core/handler/SSMLHandler.py new file mode 100644 index 0000000..4198462 --- /dev/null +++ b/modules/core/handler/SSMLHandler.py @@ -0,0 +1,64 @@ +from typing import Generator + + +from modules.core.handler.AudioHandler import AudioHandler +from modules.core.handler.datacls.audio_model import AdjustConfig +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.pipeline.factory import PipelineFactory +from modules.core.pipeline.processor import NP_AUDIO, TTSPipelineContext + + +class SSMLHandler(AudioHandler): + def __init__( + self, + ssml_content: str, + infer_config: InferConfig, + adjust_config: AdjustConfig, + enhancer_config: EnhancerConfig, + ) -> None: + assert isinstance(ssml_content, str), "ssml_content must be a string." 
+ assert isinstance( + infer_config, InferConfig + ), "infer_config must be an InferConfig object." + assert isinstance( + adjust_config, AdjustConfig + ), "adjest_config should be AdjustConfig" + assert isinstance( + enhancer_config, EnhancerConfig + ), "enhancer_config must be an EnhancerConfig object." + + self.ssml_content = ssml_content + self.infer_config = infer_config + self.adjest_config = adjust_config + self.enhancer_config = enhancer_config + + self.validate() + + def validate(self): + # TODO params checker + pass + + def create_pipeline(self): + ssml_content = self.ssml_content + infer_config = self.infer_config + adjust_config = self.adjest_config + enhancer_config = self.enhancer_config + + ctx = TTSPipelineContext( + ssml=ssml_content, + tts_config=ChatTTSConfig(), + infer_config=infer_config, + adjust_config=adjust_config, + enhancer_config=enhancer_config, + ) + pipeline = PipelineFactory.create(ctx) + return pipeline + + def enqueue(self) -> NP_AUDIO: + pipeline = self.create_pipeline() + return pipeline.generate() + + def enqueue_stream(self) -> Generator[NP_AUDIO, None, None]: + pipeline = self.create_pipeline() + return pipeline.generate_stream() diff --git a/modules/core/handler/TTSHandler.py b/modules/core/handler/TTSHandler.py new file mode 100644 index 0000000..84e7eb7 --- /dev/null +++ b/modules/core/handler/TTSHandler.py @@ -0,0 +1,80 @@ +import logging +from typing import Generator + + +from modules.core.handler.AudioHandler import AudioHandler +from modules.core.handler.datacls.audio_model import AdjustConfig +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.pipeline.dcls import TTSPipelineContext +from modules.core.pipeline.factory import PipelineFactory +from modules.core.pipeline.processor import NP_AUDIO +from modules.core.speaker import Speaker + +logger = logging.getLogger(__name__) + + +class TTSHandler(AudioHandler): + def __init__( + self, + text_content: str, + spk: Speaker, + tts_config: ChatTTSConfig, + infer_config: InferConfig, + adjust_config: AdjustConfig, + enhancer_config: EnhancerConfig, + ): + assert isinstance(text_content, str), "text_content should be str" + assert isinstance(spk, Speaker), "spk should be Speaker" + assert isinstance( + tts_config, ChatTTSConfig + ), "tts_config should be ChatTTSConfig" + assert isinstance( + infer_config, InferConfig + ), "infer_config should be InferConfig" + assert isinstance( + adjust_config, AdjustConfig + ), "adjest_config should be AdjustConfig" + assert isinstance( + enhancer_config, EnhancerConfig + ), "enhancer_config should be EnhancerConfig" + + self.text_content = text_content + self.spk = spk + self.tts_config = tts_config + self.infer_config = infer_config + self.adjest_config = adjust_config + self.enhancer_config = enhancer_config + + self.validate() + + def validate(self): + # TODO params checker + pass + + def create_pipeline(self): + text_content = self.text_content + infer_config = self.infer_config + tts_config = self.tts_config + adjust_config = self.adjest_config + enhancer_config = self.enhancer_config + spk = self.spk + + ctx = TTSPipelineContext( + text=text_content, + spk=spk, + tts_config=tts_config, + infer_config=infer_config, + adjust_config=adjust_config, + enhancer_config=enhancer_config, + ) + pipeline = PipelineFactory.create(ctx) + return pipeline + + def enqueue(self) -> NP_AUDIO: + pipeline = self.create_pipeline() + return pipeline.generate() 
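Taken together with the new `AudioHandler`, this is how the refactored handler API is driven end to end by the API routes in this diff. An illustrative usage sketch (the speaker name is a placeholder; `speaker_mgr.get_speaker` is the same lookup the pipeline factory uses):

```python
from modules.core.handler.TTSHandler import TTSHandler
from modules.core.handler.datacls.audio_model import AdjustConfig, AudioFormat
from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig
from modules.core.handler.datacls.enhancer_model import EnhancerConfig
from modules.core.speaker import speaker_mgr

# Resolve a speaker, then hand everything to the handler; the handler
# builds a TTSPipeline via PipelineFactory internally.
spk = speaker_mgr.get_speaker("female2")  # speaker name is a placeholder

handler = TTSHandler(
    text_content="Hello from the new core pipeline.",
    spk=spk,
    tts_config=ChatTTSConfig(temperature=0.3, top_p=0.7, top_k=20),
    infer_config=InferConfig(eos="[uv_break]", seed=42),
    adjust_config=AdjustConfig(),
    enhancer_config=EnhancerConfig(),
)

# Non-streaming path: encode the full result and write it out.
mp3_bytes = handler.enqueue_to_bytes(format=AudioFormat.mp3)
with open("out.mp3", "wb") as f:
    f.write(mp3_bytes)
```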
+ + def enqueue_stream(self) -> Generator[NP_AUDIO, None, None]: + pipeline = self.create_pipeline() + return pipeline.generate_stream() diff --git a/modules/api/impl/model/audio_model.py b/modules/core/handler/datacls/audio_model.py similarity index 90% rename from modules/api/impl/model/audio_model.py rename to modules/core/handler/datacls/audio_model.py index af4bc73..49eb9db 100644 --- a/modules/api/impl/model/audio_model.py +++ b/modules/core/handler/datacls/audio_model.py @@ -7,6 +7,8 @@ class AudioFormat(str, Enum): mp3 = "mp3" wav = "wav" ogg = "ogg" + acc = "acc" + flac = "flac" class AdjustConfig(BaseModel): diff --git a/modules/api/impl/model/chattts_model.py b/modules/core/handler/datacls/chattts_model.py similarity index 62% rename from modules/api/impl/model/chattts_model.py rename to modules/core/handler/datacls/chattts_model.py index 10eaa46..3d38cc6 100644 --- a/modules/api/impl/model/chattts_model.py +++ b/modules/core/handler/datacls/chattts_model.py @@ -2,10 +2,14 @@ class ChatTTSConfig(BaseModel): + # model id + mid: str = "chat-tts" + style: str = "" temperature: float = 0.3 top_p: float = 0.7 top_k: int = 20 + prompt: str = "" prompt1: str = "" prompt2: str = "" prefix: str = "" @@ -17,3 +21,11 @@ class InferConfig(BaseModel): # end_of_sentence eos: str = "[uv_break]" seed: int = 42 + + stream: bool = False + stream_chunk_size: int = 96 + + no_cache: bool = False + + # 开启同步生成 (主要是给gradio用) + sync_gen: bool = False diff --git a/modules/api/impl/model/enhancer_model.py b/modules/core/handler/datacls/enhancer_model.py similarity index 100% rename from modules/api/impl/model/enhancer_model.py rename to modules/core/handler/datacls/enhancer_model.py diff --git a/modules/core/handler/datacls/tn_model.py b/modules/core/handler/datacls/tn_model.py new file mode 100644 index 0000000..868bb2b --- /dev/null +++ b/modules/core/handler/datacls/tn_model.py @@ -0,0 +1,8 @@ +from typing import Optional + +from pydantic import BaseModel + + +class TNConfig(BaseModel): + enabled: Optional[list[str]] = None + disabled: Optional[list[str]] = None diff --git a/modules/core/handler/encoder/StreamEncoder.py b/modules/core/handler/encoder/StreamEncoder.py new file mode 100644 index 0000000..6431a11 --- /dev/null +++ b/modules/core/handler/encoder/StreamEncoder.py @@ -0,0 +1,82 @@ +import pydub +import pydub.utils +import subprocess +import threading +import queue +import io + + +class StreamEncoder: + def __init__(self) -> None: + self.encoder = pydub.utils.get_encoder_name() + self.p: subprocess.Popen = None + self.output_queue = queue.Queue() + self.read_thread = None + self.chunk_size = 1024 + + def open( + self, format: str = "mp3", acodec: str = "libmp3lame", bitrate: str = "320k" + ): + encoder = self.encoder + self.p = subprocess.Popen( + [ + encoder, + "-f", + "wav", + "-i", + "pipe:0", + "-f", + format, + "-acodec", + acodec, + "-b:a", + bitrate, + "-", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + self.read_thread = threading.Thread(target=self._read_output) + self.read_thread.daemon = True + self.read_thread.start() + + def _read_output(self): + while self.p: + data = self.p.stdout.read(self.chunk_size) + if not data: + break + self.output_queue.put(data) + + def write(self, data: bytes): + if self.p is None: + raise Exception("Encoder is not open") + self.p.stdin.write(data) + self.p.stdin.flush() + + def read(self, timeout=0.1) -> bytes: + if self.p is None: + raise Exception("Encoder is not open") + try: + return 
self.output_queue.get(timeout=timeout) + except queue.Empty: + return b"" + + def read_all(self, timeout=5) -> bytes: + if self.p is None: + raise Exception("Encoder is not open") + data = b"" + while True: + try: + data += self.output_queue.get(timeout=timeout) + except queue.Empty: + break + return data + + def close(self): + if self.p is None: + return + self.p.stdin.close() + self.p.wait() + if self.read_thread: + self.read_thread.join() diff --git a/modules/core/handler/encoder/WavFile.py b/modules/core/handler/encoder/WavFile.py new file mode 100644 index 0000000..44efed8 --- /dev/null +++ b/modules/core/handler/encoder/WavFile.py @@ -0,0 +1,104 @@ +import struct +import io +import logging + + +class WAVFileBytes: + + def __init__(self, wav_bytes): + self.wav_bytes = wav_bytes + self.riff = None + self.size = None + self.fformat = None + self.aformat = None + self.channels = None + self.samplerate = None + self.bitrate = None + self.subchunks = [] + self.header_end = 0 + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + + def read(self): + with io.BytesIO(self.wav_bytes) as fh: + self.riff, self.size, self.fformat = struct.unpack("<4sI4s", fh.read(12)) + logging.info( + "Riff: %s, Chunk Size: %i, format: %s", + self.riff, + self.size, + self.fformat, + ) + + # Read header + chunk_header = fh.read(8) + subchunkid, subchunksize = struct.unpack("<4sI", chunk_header) + + if subchunkid == b"fmt ": + ( + self.aformat, + self.channels, + self.samplerate, + byterate, + blockalign, + bps, + ) = struct.unpack("HHIIHH", fh.read(16)) + self.bitrate = (self.samplerate * self.channels * bps) / 1024 + logging.info( + "Format: %i, Channels %i, Sample Rate: %i, Kbps: %i", + self.aformat, + self.channels, + self.samplerate, + self.bitrate, + ) + + chunkOffset = fh.tell() + while chunkOffset < self.size: + fh.seek(chunkOffset) + subchunk2id, subchunk2size = struct.unpack("<4sI", fh.read(8)) + logging.info("chunk id: %s, size: %i", subchunk2id, subchunk2size) + subchunk_data = {"id": subchunk2id, "size": subchunk2size} + + if subchunk2id == b"LIST": + listtype = struct.unpack("<4s", fh.read(4)) + subchunk_data["listtype"] = listtype + logging.info( + "\tList Type: %s, List Size: %i", listtype, subchunk2size + ) + + listOffset = 0 + list_items = [] + while (subchunk2size - 8) >= listOffset: + listitemid, listitemsize = struct.unpack("<4sI", fh.read(8)) + listOffset = listOffset + listitemsize + 8 + listdata = fh.read(listitemsize) + list_items.append( + { + "id": listitemid.decode("ascii"), + "size": listitemsize, + "data": listdata.decode("ascii"), + } + ) + logging.info( + "\tList id %s, size: %i, data: %s", + listitemid.decode("ascii"), + listitemsize, + listdata.decode("ascii"), + ) + logging.info("\tOffset: %i", listOffset) + subchunk_data["items"] = list_items + elif subchunk2id == b"data": + logging.info("Found data") + else: + subchunk_data["data"] = fh.read(subchunk2size).decode("ascii") + logging.info("Data: %s", subchunk_data["data"]) + + self.subchunks.append(subchunk_data) + chunkOffset = chunkOffset + subchunk2size + 8 + + self.header_end = fh.tell() + + def get_header_data(self): + return self.wav_bytes[: self.header_end] + + def get_body_data(self): + return self.wav_bytes[self.header_end :] diff --git a/modules/core/handler/encoder/encoders.py b/modules/core/handler/encoder/encoders.py new file mode 100644 index 0000000..2453e37 --- /dev/null +++ b/modules/core/handler/encoder/encoders.py @@ -0,0 +1,31 @@ +from modules.core.handler.encoder.StreamEncoder 
import StreamEncoder + + +class Mp3Encoder(StreamEncoder): + + def open(self, acodec: str = "libmp3lame", bitrate: str = "128k"): + return super().open("mp3", acodec, bitrate) + + +class WavEncoder(StreamEncoder): + + def open(self, acodec: str = "pcm_s16le", bitrate: str = "128k"): + return super().open("wav", acodec, bitrate) + + +class OggEncoder(StreamEncoder): + + def open(self, acodec: str = "libvorbis", bitrate: str = "128k"): + return super().open("ogg", acodec, bitrate) + + +class FlacEncoder(StreamEncoder): + + def open(self, acodec: str = "flac", bitrate: str = "128k"): + return super().open("flac", acodec, bitrate) + + +class AacEncoder(StreamEncoder): + + def open(self, acodec: str = "aac", bitrate: str = "128k"): + return super().open("aac", acodec, bitrate) diff --git a/tests/test_enhance/test_denoiser.py b/modules/core/models/RefinerModel.py similarity index 100% rename from tests/test_enhance/test_denoiser.py rename to modules/core/models/RefinerModel.py diff --git a/modules/core/models/TTSModel.py b/modules/core/models/TTSModel.py new file mode 100644 index 0000000..e112917 --- /dev/null +++ b/modules/core/models/TTSModel.py @@ -0,0 +1,40 @@ +from typing import Generator, Optional + +import numpy as np + +from modules.core.pipeline.dcls import TTSSegment +from modules.core.pipeline.processor import NP_AUDIO, TTSPipelineContext +from modules.core.tn.TNPipeline import TNPipeline + + +class TTSModel: + + def __init__(self, name: str) -> None: + self.name = name + self.hash = "" + self.tn: Optional[TNPipeline] = None + + def load(self, context: TTSPipelineContext) -> None: + pass + + def unload(self, context: TTSPipelineContext) -> None: + pass + + def generate(self, segment: TTSSegment, context: TTSPipelineContext) -> NP_AUDIO: + return self.generate_batch([segment], context=context)[0] + + def generate_batch( + self, segments: list[TTSSegment], context: TTSPipelineContext + ) -> list[NP_AUDIO]: + raise NotImplementedError("generate_batch method is not implemented") + + def generate_stream( + self, segment: TTSSegment, context: TTSPipelineContext + ) -> Generator[NP_AUDIO, None, None]: + for batch in self.generate_batch_stream([segment], context=context): + yield batch[0] + + def generate_batch_stream( + self, segments: list[TTSSegment], context: TTSPipelineContext + ) -> Generator[list[NP_AUDIO], None, None]: + raise NotImplementedError("generate_batch_stream method is not implemented") diff --git a/modules/core/models/__init__.py b/modules/core/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/models/refiner/OaiRefiner.py b/modules/core/models/refiner/OaiRefiner.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/models/tts/ChatTtsModel.py b/modules/core/models/tts/ChatTtsModel.py new file mode 100644 index 0000000..31897c9 --- /dev/null +++ b/modules/core/models/tts/ChatTtsModel.py @@ -0,0 +1,198 @@ +from typing import Any, Generator, Union + +import numpy as np + +from modules.core.models.TTSModel import TTSModel +from modules.core.models.zoo.ChatTTS import ChatTTS, load_chat_tts, unload_chat_tts +from modules.core.models.zoo.ChatTTSInfer import ChatTTSInfer +from modules.core.models.tts.InerCache import InferCache +from modules.core.pipeline.dcls import TTSPipelineContext +from modules.core.pipeline.pipeline import TTSSegment +from modules.core.pipeline.processor import NP_AUDIO +from modules.utils.SeedContext import SeedContext + + +class ChatTTSModel(TTSModel): + model_id = "chat-tts" + + def __init__(self) 
-> None: + super().__init__("chat-tts") + self.chat: ChatTTS = None + + def load(self, context: TTSPipelineContext) -> ChatTTS: + self.chat = load_chat_tts() + return self.chat + + def unload(self, context: TTSPipelineContext) -> None: + unload_chat_tts(self.chat) + self.chat = None + + def generate_batch( + self, segments: list[TTSSegment], context: TTSPipelineContext + ) -> list[NP_AUDIO]: + return self.generate_batch_base(segments, context, stream=False) + + def generate_batch_stream( + self, segments: list[TTSSegment], context: TTSPipelineContext + ) -> Generator[list[NP_AUDIO], Any, None]: + return self.generate_batch_base(segments, context, stream=True) + + def get_infer(self, context: TTSPipelineContext): + return ChatTTSInfer(self.load(context=context)) + + def get_cache_kwargs(self, segments: list[TTSSegment], context: TTSPipelineContext): + texts = [segment.text for segment in segments] + + seg0 = segments[0] + spk_emb = seg0.spk.emb if seg0.spk else None + top_P = seg0.top_p + top_K = seg0.top_k + temperature = seg0.temperature + # repetition_penalty = seg0.repetition_penalty + # max_new_token = seg0.max_new_token + prompt1 = seg0.prompt1 + prompt2 = seg0.prompt2 + prefix = seg0.prefix + # use_decoder = seg0.use_decoder + seed = seg0.infer_seed + chunk_size = context.infer_config.stream_chunk_size + + kwargs = dict( + text="|".join(texts), + spk_emb=spk_emb, + top_P=top_P, + top_K=top_K, + temperature=temperature, + repetition_penalty=None, + max_new_token=None, + prompt1=prompt1, + prompt2=prompt2, + prefix=prefix, + stream_chunk_size=chunk_size, + seed=seed, + ) + return kwargs + + def get_cache( + self, segments: list[TTSSegment], context: TTSPipelineContext + ) -> Union[list[NP_AUDIO], None]: + no_cache = context.infer_config.no_cache + if no_cache: + return None + + is_random_generate = context.infer_config.seed == -1 + if is_random_generate: + return None + + kwargs = self.get_cache_kwargs(segments=segments, context=context) + + if InferCache.get_cache_val(model_id=self.model_id, **kwargs): + return InferCache.get_cache_val(model_id=self.model_id, **kwargs) + + return None + + def set_cache( + self, + segments: list[TTSSegment], + context: TTSPipelineContext, + value: list[NP_AUDIO], + ): + no_cache = context.infer_config.no_cache + if no_cache: + return + + kwargs = self.get_cache_kwargs(segments=segments, context=context) + InferCache.set_cache_val(model_id=self.model_id, value=value, **kwargs) + + def generate_batch_base( + self, segments: list[TTSSegment], context: TTSPipelineContext, stream=False + ) -> Union[list[NP_AUDIO], Generator[list[NP_AUDIO], Any, None]]: + cached = self.get_cache(segments=segments, context=context) + if cached is not None: + if not stream: + return cached + + def _gen(): + yield cached + + return _gen() + + infer = self.get_infer(context) + + texts = [segment.text for segment in segments] + + seg0 = segments[0] + spk_emb = seg0.spk.emb if seg0.spk else None + top_P = seg0.top_p + top_K = seg0.top_k + temperature = seg0.temperature + # repetition_penalty = seg0.repetition_penalty + # max_new_token = seg0.max_new_token + prompt1 = seg0.prompt1 + prompt2 = seg0.prompt2 + prefix = seg0.prefix + # use_decoder = seg0.use_decoder + seed = seg0.infer_seed + chunk_size = context.infer_config.stream_chunk_size + + sr = 24000 + + if not stream: + with SeedContext(seed, cudnn_deterministic=False): + results = infer.generate_audio( + text=texts, + spk_emb=spk_emb, + top_P=top_P, + top_K=top_K, + temperature=temperature, + prompt1=prompt1, + 
prompt2=prompt2, + prefix=prefix, + ) + audio_arr: list[NP_AUDIO] = [ + # NOTE: data[0] 的意思是 立体声 => mono audio + (sr, np.empty(0)) if data is None else (sr, data[0]) + for data in results + ] + + self.set_cache(segments=segments, context=context, value=audio_arr) + return audio_arr + else: + + def _gen() -> Generator[list[NP_AUDIO], None, None]: + audio_arr_buff = None + with SeedContext(seed, cudnn_deterministic=False): + for results in infer.generate_audio_stream( + text=texts, + spk_emb=spk_emb, + top_P=top_P, + top_K=top_K, + temperature=temperature, + prompt1=prompt1, + prompt2=prompt2, + prefix=prefix, + stream_chunk_size=chunk_size, + ): + results = [ + ( + np.empty(0) + # None 应该是生成失败, size === 0 是生成结束 + if data is None or data.size == 0 + # NOTE: data[0] 的意思是 立体声 => mono audio + else data[0] + ) + for data in results + ] + audio_arr: list[NP_AUDIO] = [(sr, data) for data in results] + yield audio_arr + + if audio_arr_buff is None: + audio_arr_buff = audio_arr + else: + for i, data in enumerate(results): + sr1, before = audio_arr_buff[i] + buff = np.concatenate([before, data], axis=0) + audio_arr_buff[i] = (sr1, buff) + self.set_cache(segments=segments, context=context, value=audio_arr_buff) + + return _gen() diff --git a/modules/core/models/tts/InerCache.py b/modules/core/models/tts/InerCache.py new file mode 100644 index 0000000..dccbee0 --- /dev/null +++ b/modules/core/models/tts/InerCache.py @@ -0,0 +1,86 @@ +from typing import Dict +from cachetools import LRUCache +import torch +from cachetools import keys as cache_keys + +from modules.core.speaker import Speaker + + +def hash_tensor(tensor: torch.Tensor): + """ + NOTE: 在不同执行虚拟机中此函数可能不稳定 + NOTE: 但是用来计算 cache 足够了 + """ + + return hash(tuple(tensor.reshape(-1).tolist())) + + +class InferCache: + caches: Dict[str, LRUCache] = {} + + @classmethod + def get_cache(cls, model_id: str) -> LRUCache: + if model_id in InferCache.caches: + return InferCache.caches.get(model_id) + cache = LRUCache(maxsize=128) + InferCache.caches[model_id] = cache + return cache + + @classmethod + def get_hash_key(cls, *args, **kwargs): + for i, arg in enumerate(args): + if isinstance(arg, Speaker): + args[i] = str(arg.id) + if isinstance(arg, torch.Tensor): + args[i] = hash_tensor(arg) + + for key, value in kwargs.items(): + if isinstance(value, Speaker): + kwargs[key] = str(value.id) + if isinstance(value, torch.Tensor): + kwargs[key] = hash_tensor(value) + + cachekey = cache_keys.hashkey(*args, **kwargs) + return cachekey + + @classmethod + def get_cache_val(cls, model_id: str, *args, **kwargs): + key = cls.get_hash_key(*args, **kwargs) + cache = InferCache.get_cache(model_id) + + if key in cache: + return cache[key] + + return None + + @classmethod + def set_cache_val(cls, model_id: str, value, *args, **kwargs): + key = cls.get_hash_key(*args, **kwargs) + cache = InferCache.get_cache(model_id) + + cache[key] = value + + @classmethod + def cached(cls, model_id: str, should_cache: callable = None): + """ + 包装器 + + should_cache 用于提供 条件化 cache + """ + + def decorator(func): + def wrapper(*args, **kwargs): + if should_cache is not None and not should_cache(*args, **kwargs): + return func(*args, **kwargs) + + cached = cls.get_cache_val(model_id, *args, **kwargs) + if cached is not None: + return cached + + result = func(*args, **kwargs) + cls.set_cache_val(model_id, result, *args, **kwargs) + return result + + return wrapper + + return decorator diff --git a/modules/models.py b/modules/core/models/zoo/ChatTTS.py similarity index 88% rename from 
modules/models.py rename to modules/core/models/zoo/ChatTTS.py index ab21e89..f1b2390 100644 --- a/modules/models.py +++ b/modules/core/models/zoo/ChatTTS.py @@ -15,15 +15,17 @@ lock = threading.Lock() -def load_chat_tts_in_thread(): +def do_load_chat_tts(): global chat_tts if chat_tts: return logger.info("Loading ChatTTS models") chat_tts = ChatTTS.Chat() + device = devices.get_device_for("chattts") dtype = devices.dtype + chat_tts.load( compile=config.runtime_env_vars.compile, use_flash_attn=config.runtime_env_vars.flash_attn, @@ -31,10 +33,6 @@ def load_chat_tts_in_thread(): custom_path="./models/ChatTTS", device=device, dtype=dtype, - # dtype_vocos=devices.dtype_vocos, - # dtype_dvae=devices.dtype_dvae, - # dtype_gpt=devices.dtype_gpt, - # dtype_decoder=devices.dtype_decoder, ) # 如果 device 为 cpu 同时,又是 dtype == float16 那么报 warn @@ -50,8 +48,7 @@ def load_chat_tts_in_thread(): def load_chat_tts(): with lock: - if chat_tts is None: - load_chat_tts_in_thread() + do_load_chat_tts() if chat_tts is None: raise Exception("Failed to load ChatTTS models") return chat_tts @@ -83,4 +80,6 @@ def reload_chat_tts(): def get_tokenizer() -> LlamaTokenizer: chat_tts = load_chat_tts() tokenizer = chat_tts.pretrain_models["tokenizer"] + if not tokenizer: + raise Exception("Failed to load tokenizer") return tokenizer diff --git a/modules/ChatTTSInfer.py b/modules/core/models/zoo/ChatTTSInfer.py similarity index 89% rename from modules/ChatTTSInfer.py rename to modules/core/models/zoo/ChatTTSInfer.py index ad7874e..1c926ce 100644 --- a/modules/ChatTTSInfer.py +++ b/modules/core/models/zoo/ChatTTSInfer.py @@ -9,6 +9,7 @@ from modules import config from modules.ChatTTS.ChatTTS.core import Chat from modules.ChatTTS.ChatTTS.model import GPT +from modules.core.models import zoo from modules.utils.monkey_tqdm import disable_tqdm @@ -38,6 +39,8 @@ def del_all(d: Union[dict, list]): class ChatTTSInfer: + model_id = "chat-tts" + logger = logging.getLogger(__name__) current_infer = None @@ -46,6 +49,9 @@ def __init__(self, instance: Chat) -> None: self.instance = instance ChatTTSInfer.current_infer = self + if zoo.zoo_config.debug_generate: + self.logger.setLevel(logging.DEBUG) + def get_tokenizer(self) -> LlamaTokenizer: return self.instance.pretrain_models["tokenizer"] @@ -58,7 +64,7 @@ def interrupt(cls): def infer( self, - text: str, + text: Union[str, list[str]], stream=False, skip_refine_text=False, refine_text_only=False, @@ -84,7 +90,7 @@ def infer( @torch.inference_mode() def _infer( self, - text: str, + text: Union[str, list[str]], stream=False, skip_refine_text=False, refine_text_only=False, @@ -100,6 +106,14 @@ def _infer( # smooth_decoding = stream smooth_decoding = False + self.logger.debug( + f"Start infer: stream={stream}, skip_refine_text={skip_refine_text}, refine_text_only={refine_text_only}, use_decoder={use_decoder}, smooth_decoding={smooth_decoding}" + ) + self.logger.debug( + f"params_refine_text={params_refine_text}, params_infer_code={params_infer_code}" + ) + self.logger.debug(f"Text: {text}") + with torch.no_grad(): if not skip_refine_text: @@ -129,6 +143,7 @@ def _infer( wavs = self._decode_to_wavs(result, length, use_decoder) yield wavs else: + # NOTE: 貌似没什么用...? 
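The `InferCache` added a little earlier in this diff memoizes generation results in a `cachetools.LRUCache`, replacing unhashable arguments (speaker objects, tensors) with stable stand-ins before deriving a key via `cachetools.keys.hashkey`. A minimal sketch of that lookup pattern, assuming all keyword arguments are already hashable:

```python
from cachetools import LRUCache
from cachetools import keys as cache_keys

# One shared LRU cache, mirroring InferCache's per-model caches (maxsize=128).
_cache = LRUCache(maxsize=128)


def cached_generate(generate_fn, **kwargs):
    # Derive a single hashable key from the call parameters.
    key = cache_keys.hashkey(**kwargs)
    if key in _cache:
        return _cache[key]
    result = generate_fn(**kwargs)
    _cache[key] = result
    return result
```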
# smooth_decoding 即使用了滑动窗口的解码,每次都保留上一段的隐藏状态一起解码,并且保留上一段的音频长度用于截取 @dataclass(repr=False, eq=False) class WavWindow: @@ -251,7 +266,7 @@ def _decode_to_wavs( def _generate_audio( self, - text: str, + text: Union[str, list[str]], spk_emb: Union[None, torch.Tensor] = None, top_P=0.7, top_K=20, @@ -261,6 +276,7 @@ def _generate_audio( prompt1="", prompt2="", prefix="", + stream_chunk_size=96, use_decoder=True, stream=False, ): @@ -275,6 +291,7 @@ def _generate_audio( prompt1=prompt1, prompt2=prompt2, prefix=prefix, + stream_chunk_size=stream_chunk_size, ensure_non_empty=False, ) return self.infer( @@ -287,7 +304,7 @@ def _generate_audio( def generate_audio( self, - text: str, + text: Union[str, list[str]], spk_emb: Union[None, torch.Tensor] = None, top_P=0.7, top_K=20, @@ -314,11 +331,12 @@ def generate_audio( use_decoder=use_decoder, stream=False, ) - return [i for i in data if i is not None] + data = [i for i in data if i is not None] + return data def generate_audio_stream( self, - text: str, + text: Union[str, list[str]], spk_emb=None, top_P=0.7, top_K=20, @@ -328,9 +346,10 @@ def generate_audio_stream( prompt1="", prompt2="", prefix="", + stream_chunk_size=96, use_decoder=True, ) -> Generator[list[np.ndarray], None, None]: - gen = self._generate_audio( + gen: Generator[list[np.ndarray], None, None] = self._generate_audio( text=text, spk_emb=spk_emb, top_P=top_P, @@ -342,20 +361,23 @@ def generate_audio_stream( prompt2=prompt2, prefix=prefix, use_decoder=use_decoder, + stream_chunk_size=stream_chunk_size, stream=True, ) def _generator(): with disable_tqdm(enabled=config.runtime_env_vars.off_tqdm): - for audio in gen: - if audio is not None: - yield audio + for audio_arr in gen: + # 如果为空就用空 array 填充 + # NOTE: 因为长度不一定,所以某个位置可能是 None + audio_arr = [np.empty(0) if i is None else i for i in audio_arr] + yield audio_arr return _generator() def _refine_text( self, - text: str, + text: Union[str, list[str]], top_P=0.7, top_K=20, temperature=0.7, @@ -384,7 +406,7 @@ def _refine_text( def refine_text( self, - text: str, + text: Union[str, list[str]], top_P=0.7, top_K=20, temperature=0.7, @@ -405,7 +427,7 @@ def refine_text( def refine_text_stream( self, - text: str, + text: Union[str, list[str]], top_P=0.7, top_K=20, temperature=0.7, diff --git a/modules/core/models/zoo/__init__.py b/modules/core/models/zoo/__init__.py new file mode 100644 index 0000000..a4709b7 --- /dev/null +++ b/modules/core/models/zoo/__init__.py @@ -0,0 +1,2 @@ +from . import ChatTTS +from . 
import zoo_config diff --git a/modules/core/models/zoo/zoo_config.py b/modules/core/models/zoo/zoo_config.py new file mode 100644 index 0000000..4cc2ab8 --- /dev/null +++ b/modules/core/models/zoo/zoo_config.py @@ -0,0 +1 @@ +debug_generate = False diff --git a/modules/core/pipeline/__init__.py b/modules/core/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/dcls.py b/modules/core/pipeline/dcls.py new file mode 100644 index 0000000..5d82f30 --- /dev/null +++ b/modules/core/pipeline/dcls.py @@ -0,0 +1,42 @@ +from typing import Literal, Optional +from dataclasses import dataclass + +from modules.core.handler.datacls.audio_model import AdjustConfig +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.datacls.tn_model import TNConfig +from modules.core.speaker import Speaker + + +@dataclass +class TTSSegment: + _type: Literal["break", "audio"] + duration_s: int = 0 + + text: str = "" + temperature: float = 0.3 + top_p: float = 0.9 + top_k: int = 20 + infer_seed: int = 42 + prompt: str = "" + prompt1: str = "" + prompt2: str = "" + prefix: str = "" + + spk: Speaker = None + + +@dataclass +class TTSPipelineContext: + text: Optional[str] = None + ssml: Optional[str] = None + + spk: Optional[Speaker] = None + tts_config: ChatTTSConfig = ChatTTSConfig() + infer_config: InferConfig = InferConfig() + adjust_config: AdjustConfig = AdjustConfig() + enhancer_config: EnhancerConfig = EnhancerConfig() + + tn_config: TNConfig = TNConfig() + + stop: bool = False diff --git a/modules/core/pipeline/factory.py b/modules/core/pipeline/factory.py new file mode 100644 index 0000000..8b3ed68 --- /dev/null +++ b/modules/core/pipeline/factory.py @@ -0,0 +1,143 @@ +from numpy import ndarray +from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full +from modules.core.models.tts.ChatTtsModel import ChatTTSModel +from modules.core.pipeline.dcls import TTSSegment +from modules.core.pipeline.pipeline import TTSPipeline +from modules.core.pipeline.processor import ( + NP_AUDIO, + AudioProcessor, + TTSPipelineContext, + PreProcessor, +) +from modules.core.tn.ChatTtsTN import ChatTtsTN +from modules.utils import audio_utils +from modules.data import styles_mgr +from modules.core.speaker import Speaker, speaker_mgr + +import logging + +logger = logging.getLogger(__name__) + + +class EnhancerProcessor(AudioProcessor): + def _process_array( + self, audio: tuple[int, ndarray], context: TTSPipelineContext + ) -> tuple[int, ndarray]: + enhancer_config = context.enhancer_config + + if not enhancer_config.enabled: + return audio + nfe = enhancer_config.nfe + solver = enhancer_config.solver + lambd = enhancer_config.lambd + tau = enhancer_config.tau + + sample_rate, audio_data = audio + audio_data, sample_rate = apply_audio_enhance_full( + audio_data=audio_data, + sr=sample_rate, + nfe=nfe, + solver=solver, + lambd=lambd, + tau=tau, + ) + + return sample_rate, audio_data + + +class AdjusterProcessor(AudioProcessor): + def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO: + sample_rate, audio_data = audio + adjust_config = context.adjust_config + + audio_data = audio_utils.apply_prosody_to_audio_data( + audio_data=audio_data, + rate=adjust_config.speed_rate, + pitch=adjust_config.pitch, + volume=adjust_config.volume_gain_db, + sr=sample_rate, + ) + return sample_rate, audio_data + + +class 
AudioNormalizer(AudioProcessor): + def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO: + adjust_config = context.adjust_config + if not adjust_config.normalize: + return audio + sample_rate, audio_data = audio + sample_rate, audio_data = audio_utils.apply_normalize( + audio_data=audio_data, headroom=adjust_config.headroom, sr=sample_rate + ) + return sample_rate, audio_data + + +class ChatTtsTNProcessor(PreProcessor): + def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegment: + segment.text = ChatTtsTN.normalize(text=segment.text, config=context.tn_config) + return segment + + +class ChatTtsStyleProcessor(PreProcessor): + """ + 计算合并 style/spk + """ + + def get_style_params(self, context: TTSPipelineContext): + style = context.tts_config.style + if not style: + return {} + params = styles_mgr.find_params_by_name(style) + return params + + def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegment: + params = self.get_style_params(context) + segment.prompt = ( + segment.prompt or context.tts_config.prompt or params.get("prompt", "") + ) + segment.prompt1 = ( + segment.prompt1 or context.tts_config.prompt1 or params.get("prompt1", "") + ) + segment.prompt2 = ( + segment.prompt2 or context.tts_config.prompt2 or params.get("prompt2", "") + ) + segment.prefix = ( + segment.prefix or context.tts_config.prefix or params.get("prefix", "") + ) + + spk = segment.spk or context.spk + + if isinstance(spk, str): + if spk == "": + spk = None + else: + spk = speaker_mgr.get_speaker(spk) + if spk and not isinstance(spk, Speaker): + spk = None + logger.warn(f"Invalid spk: {spk}") + + segment.spk = spk + + return segment + + +class PipelineFactory: + @classmethod + def create(cls, ctx: TTSPipelineContext) -> TTSPipeline: + model_id = ctx.tts_config.mid + + if model_id == "chat-tts": + return cls.create_chattts_pipeline(ctx) + else: + raise Exception(f"Unknown model id: {model_id}") + + @classmethod + def create_chattts_pipeline(cls, ctx: TTSPipelineContext): + pipeline = TTSPipeline(ctx) + pipeline.add_module(ChatTtsTNProcessor()) + pipeline.add_module(EnhancerProcessor()) + pipeline.add_module(AdjusterProcessor()) + pipeline.add_module(AudioNormalizer()) + pipeline.add_module(ChatTtsStyleProcessor()) + pipeline.set_model(ChatTTSModel()) + return pipeline diff --git a/modules/core/pipeline/generate/BatchGenerate.py b/modules/core/pipeline/generate/BatchGenerate.py new file mode 100644 index 0000000..4a92471 --- /dev/null +++ b/modules/core/pipeline/generate/BatchGenerate.py @@ -0,0 +1,79 @@ +import threading +import numpy as np +from modules.core.models.TTSModel import TTSModel +from modules.core.pipeline.dcls import TTSPipelineContext +from modules.core.pipeline.generate.dcls import TTSBatch, TTSBucket +from modules.utils import audio_utils + + +class BatchGenerate: + def __init__( + self, buckets: list[TTSBucket], context: TTSPipelineContext, model: TTSModel + ) -> None: + self.buckets = buckets + self.model = model + self.context = context + self.batches = self.build_batches() + + self.done = threading.Event() + + def is_done(self): + return all([seg.done for batch in self.batches for seg in batch.segments]) + + def build_batches(self) -> list[TTSBatch]: + batch_size = self.context.infer_config.batch_size + + batches = [] + for bucket in self.buckets: + for i in range(0, len(bucket.segments), batch_size): + batch = bucket.segments[i : i + batch_size] + batches.append(TTSBatch(segments=batch)) + return batches + + def 
generate(self): + stream = self.context.infer_config.stream + for batch in self.batches: + is_break = batch.segments[0].seg._type == "break" + if is_break: + self.generate_break(batch) + continue + + if stream: + self.generate_batch_stream(batch) + else: + self.generate_batch(batch) + + self.done.set() + + def generate_break(self, batch: TTSBatch): + for seg in batch.segments: + seg.data = audio_utils.silence_np(seg.seg.duration_s) + seg.done = True + + def generate_batch(self, batch: TTSBatch): + model = self.model + segments = [audio.seg for audio in batch.segments] + results = model.generate_batch(segments=segments, context=self.context) + for audio, result in zip(batch.segments, results): + sr, data = result + audio.data = data + audio.sr = sr + audio.done = True + + def generate_batch_stream(self, batch: TTSBatch): + model = self.model + segments = [audio.seg for audio in batch.segments] + + for results in model.generate_batch_stream( + segments=segments, context=self.context + ): + for audio, result in zip(batch.segments, results): + sr, data = result + if data.size == 0: + audio.done = True + continue + audio.data = np.concatenate([audio.data, data], axis=0) + audio.sr = sr + + for seg in batch.segments: + seg.done = True diff --git a/modules/core/pipeline/generate/BatchSynth.py b/modules/core/pipeline/generate/BatchSynth.py new file mode 100644 index 0000000..e98940a --- /dev/null +++ b/modules/core/pipeline/generate/BatchSynth.py @@ -0,0 +1,59 @@ +import threading +from typing import Union +from modules.core.models.TTSModel import TTSModel +from modules.core.pipeline.dcls import TTSPipelineContext, TTSSegment +from modules.core.pipeline.generate.BatchGenerate import BatchGenerate +from modules.core.pipeline.generate.Bucketizer import Bucketizer +from modules.core.pipeline.generate.SynthSteamer import SynthStreamer +from modules.core.pipeline.generate.dcls import SynthAudio + + +class BatchSynth: + def __init__( + self, + input_segments: list[TTSSegment], + context: TTSPipelineContext, + model: TTSModel, + ) -> None: + self.segments = [SynthAudio(segment=segment) for segment in input_segments] + self.streamer = SynthStreamer( + segments=self.segments, context=context, model=model + ) + self.bucketizer = Bucketizer(segments=self.segments) + self.buckets = self.bucketizer.build_buckets() + self.generator = BatchGenerate( + buckets=self.buckets, context=context, model=model + ) + self.context = context + + self.thread1 = None + + def wait_done(self, timeout: Union[float, None] = None): + self.generator.done.wait(timeout=timeout) + + def is_done(self): + return self.generator.is_done() + + def sr(self): + return self.segments[0].sr + + def read(self): + return self.streamer.read() + + def start_generate(self): + sync_gen = self.context.infer_config.sync_gen + if sync_gen: + self.start_generate_sync() + else: + self.start_generate_async() + + def start_generate_async(self): + if self.thread1 is not None: + return + gen_t1 = threading.Thread(target=self.generator.generate, args=(), daemon=True) + gen_t1.start() + self.thread1 = gen_t1 + return gen_t1 + + def start_generate_sync(self): + self.generator.generate() diff --git a/modules/core/pipeline/generate/Bucketizer.py b/modules/core/pipeline/generate/Bucketizer.py new file mode 100644 index 0000000..ee1a96e --- /dev/null +++ b/modules/core/pipeline/generate/Bucketizer.py @@ -0,0 +1,45 @@ +import copy +import json +from typing import Dict, List +from modules.core.pipeline.generate.dcls import SynthAudio, TTSBucket +from 
modules.core.speaker import Speaker + + +class Bucketizer: + def __init__(self, segments: list[SynthAudio]) -> None: + self.segments = segments + + def seg_hash(self, audio: SynthAudio): + seg = audio.seg + temp_seg = copy.deepcopy(seg) + if isinstance(temp_seg.spk, Speaker): + temp_seg.spk = str(temp_seg.spk.id) + + json_data = temp_seg.__dict__ + + return hash( + json.dumps( + {k: v for k, v in json_data.items() if k != "text"}, sort_keys=True + ) + ) + + def build_buckets(self): + buckets: Dict[str, List[SynthAudio]] = {"": []} + for segment in self.segments: + if segment.seg._type == "break": + buckets[""].append(segment) + continue + key = self.seg_hash(segment) + if key not in buckets: + buckets[key] = [] + buckets[key].append(segment) + + break_segments = buckets.pop("") + audio_segments = list(buckets.values()) + # 根据 bucket 第一个 seg 的 index 排序,越小的越靠前 + audio_segments.sort(key=lambda x: self.segments.index(x[0])) + + return [ + TTSBucket(break_segments), + *[TTSBucket(segments) for segments in audio_segments], + ] diff --git a/modules/core/pipeline/generate/Chunker.py b/modules/core/pipeline/generate/Chunker.py new file mode 100644 index 0000000..91271ac --- /dev/null +++ b/modules/core/pipeline/generate/Chunker.py @@ -0,0 +1,89 @@ +from typing import List +from modules.core.pipeline.dcls import TTSPipelineContext, TTSSegment +from modules.core.pipeline.generate.SsmlNormalizer import SsmlNormalizer +from modules.core.ssml.SSMLParser import SSMLContext, get_ssml_parser_for +from modules.core.tools.SentenceSplitter import SentenceSplitter + + +class TTSChunker: + def __init__(self, context: TTSPipelineContext) -> None: + self.context = context + + def segments(self) -> List[TTSSegment]: + text = self.context.text + ssml = self.context.ssml + + if text is not None and len(text.strip()) > 0: + return self.text_segments() + if ssml is not None and len(ssml.strip()) > 0: + return self.ssml_segments() + raise ValueError("No input text or ssml") + + def text_segments(self): + spliter_threshold = self.context.infer_config.spliter_threshold + text = self.context.text + + temperature = self.context.tts_config.temperature + top_P = self.context.tts_config.top_p + top_K = self.context.tts_config.top_k + spk = self.context.spk + infer_seed = self.context.infer_config.seed + # 这个只有 chattts 需要,并且没必要填 false... 
+ # use_decoder = self.context.tts_config.use_decoder + use_decoder = True + prompt1 = self.context.tts_config.prompt1 + prompt2 = self.context.tts_config.prompt2 + prefix = self.context.tts_config.prefix + + eos = self.context.infer_config.eos + + spliter = SentenceSplitter(spliter_threshold) + sentences = spliter.parse(text) + + text_segments = [ + TTSSegment( + _type="audio", + text=s + eos, + temperature=temperature, + top_p=top_P, + top_k=top_K, + spk=spk, + infer_seed=infer_seed, + prompt1=prompt1, + prompt2=prompt2, + prefix=prefix, + ) + for s in sentences + ] + return text_segments + + def create_ssml_ctx(self): + ctx = SSMLContext() + ctx.spk = self.context.spk + + ctx.style = self.context.tts_config.style + ctx.volume = self.context.adjust_config.volume_gain_db + ctx.rate = self.context.adjust_config.speed_rate + ctx.pitch = self.context.adjust_config.pitch + ctx.temp = self.context.tts_config.temperature + ctx.top_p = self.context.tts_config.top_p + ctx.top_k = self.context.tts_config.top_k + ctx.seed = self.context.infer_config.seed + # ctx.noramalize = self.context.tts_config.normalize + ctx.prompt1 = self.context.tts_config.prompt1 + ctx.prompt2 = self.context.tts_config.prompt2 + ctx.prefix = self.context.tts_config.prefix + + return ctx + + def ssml_segments(self): + ssml = self.context.ssml + eos = self.context.infer_config.eos + thr = self.context.infer_config.spliter_threshold + + parser = get_ssml_parser_for("0.1") + parser_ctx = self.create_ssml_ctx() + segments = parser.parse(ssml=ssml, root_ctx=parser_ctx) + normalizer = SsmlNormalizer(context=self.context, eos=eos, spliter_thr=thr) + segments = normalizer.normalize(segments) + return segments diff --git a/modules/core/pipeline/generate/SsmlNormalizer.py b/modules/core/pipeline/generate/SsmlNormalizer.py new file mode 100644 index 0000000..9127db6 --- /dev/null +++ b/modules/core/pipeline/generate/SsmlNormalizer.py @@ -0,0 +1,159 @@ +import copy +import re +from typing import List, Union +from modules.api.utils import calc_spk_style, to_number +from modules.core.pipeline.dcls import TTSPipelineContext, TTSSegment +from modules.core.ssml.SSMLParser import SSMLBreak, SSMLSegment +from modules.core.tools.SentenceSplitter import SentenceSplitter +from modules.utils import rng + + +class SsmlNormalizer: + """ + SSML segment normalizer + + input: List[SSMLSegment, SSMLBreak] => output: List[TTSSegment] + + and split the text in SSMLSegment into multiple segments if the text is too long + """ + + def __init__(self, context: TTSPipelineContext, eos="", spliter_thr=100): + self.batch_default_spk_seed = rng.np_rng() + self.batch_default_infer_seed = rng.np_rng() + self.eos = eos + self.spliter_thr = spliter_thr + self.context = context + + def append_eos(self, text: str): + text = text.strip() + eos_arr = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"] + has_eos = False + for eos in eos_arr: + if eos in text: + has_eos = True + break + if not has_eos: + text += self.eos + return text + + def convert_ssml_seg(self, segment: Union[SSMLSegment, SSMLBreak]) -> TTSSegment: + if isinstance(segment, SSMLBreak): + return TTSSegment(_type="break", duration_s=segment.attrs.duration) + + tts_config = self.context.tts_config + infer_config = self.context.infer_config + + params = segment.params + text = segment.text or "" + text = text.strip() + + if text: + text = self.append_eos(text) + + if params is not None: + return TTSSegment( + _type="voice", + text=text, + temperature=params.get("temperature", tts_config.temperature), + 
top_P=params.get("top_p", tts_config.top_p), + top_K=params.get("top_k", tts_config.top_k), + spk=params.get("spk", None), + infer_seed=params.get("seed", infer_config.seed), + prompt1=params.get("prompt1", ""), + prompt2=params.get("prompt2", ""), + prefix=params.get("prefix", ""), + ) + + text = str(text).strip() + + attrs = segment.attrs + spk = attrs.spk + style = attrs.style + + # FIXME: 这个逻辑有点...emmm 最好干掉 + ss_params = calc_spk_style(spk, style) + + if "spk" in ss_params: + spk = ss_params["spk"] + + seed = to_number(attrs.seed, int, ss_params.get("seed") or -1) + top_k = to_number(attrs.top_k, int, None) + top_p = to_number(attrs.top_p, float, None) + temp = to_number(attrs.temp, float, None) + + prompt1 = attrs.prompt1 or ss_params.get("prompt1") + prompt2 = attrs.prompt2 or ss_params.get("prompt2") + prefix = attrs.prefix or ss_params.get("prefix") + + seg = TTSSegment( + _type="voice", + text=text, + temperature=temp or tts_config.temperature, + top_p=top_p or tts_config.top_p, + top_k=top_k or tts_config.top_k, + spk=spk, + infer_seed=seed, + prompt1=prompt1, + prompt2=prompt2, + prefix=prefix, + ) + + # NOTE 每个batch的默认seed保证前后一致即使是没设置spk的情况 + if seg.spk == -1: + seg.spk = self.batch_default_spk_seed + if seg.infer_seed == -1: + seg.infer_seed = self.batch_default_infer_seed + + return seg + + def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]): + """ + 将 segments 中的 text 经过 spliter 处理成多个 segments + """ + spliter_threshold = self.context.infer_config.spliter_threshold + spliter = SentenceSplitter(threshold=spliter_threshold) + ret_segments: List[Union[SSMLSegment, SSMLBreak]] = [] + + for segment in segments: + if isinstance(segment, SSMLBreak): + ret_segments.append(segment) + continue + + text = segment.text + if not text: + continue + + sentences = spliter.parse(text) + for sentence in sentences: + seg = SSMLSegment( + text=sentence, + attrs=segment.attrs.copy(), + params=copy.copy(segment.params), + ) + ret_segments.append(seg) + setattr(seg, "_idx", len(ret_segments) - 1) + + def is_none_speak_segment(segment: SSMLSegment): + text = segment.text.strip() + regexp = r"\[[^\]]+?\]" + text = re.sub(regexp, "", text) + text = text.strip() + if not text: + return True + return False + + # 将 none_speak 合并到前一个 speak segment + for i in range(1, len(ret_segments)): + if is_none_speak_segment(ret_segments[i]): + ret_segments[i - 1].text += ret_segments[i].text + ret_segments[i].text = "" + # 移除空的 segment + ret_segments = [seg for seg in ret_segments if seg.text.strip()] + + return ret_segments + + def normalize( + self, segments: list[Union[SSMLSegment, SSMLBreak]] + ) -> list[TTSSegment]: + segments = self.split_segments(segments) + return [self.convert_ssml_seg(seg) for seg in segments] diff --git a/modules/core/pipeline/generate/SynthSteamer.py b/modules/core/pipeline/generate/SynthSteamer.py new file mode 100644 index 0000000..039b1c1 --- /dev/null +++ b/modules/core/pipeline/generate/SynthSteamer.py @@ -0,0 +1,37 @@ +import numpy as np +from modules.core.models.TTSModel import TTSModel +from modules.core.pipeline.dcls import TTSPipelineContext +from modules.core.pipeline.generate.dcls import SynthAudio + + +class SynthStreamer: + + def __init__( + self, segments: list[SynthAudio], context: TTSPipelineContext, model: TTSModel + ) -> None: + self.segments = segments + self.context = context + self.model = model + self.output_wav = np.empty(0) + + def flush(self): + output_wav = np.empty(0) + + for seg in self.segments: + data = seg.data + if data.size == 0: + 
break + output_wav = np.concatenate((output_wav, data), axis=0) + if not seg.done: + break + + self.output_wav = output_wav + return output_wav + + def write(self): + raise NotImplementedError + + def read(self) -> np.ndarray: + cursor = self.output_wav.size + self.flush() + return self.output_wav[cursor:] diff --git a/modules/core/pipeline/generate/__init__.py b/modules/core/pipeline/generate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/generate/dcls.py b/modules/core/pipeline/generate/dcls.py new file mode 100644 index 0000000..9074527 --- /dev/null +++ b/modules/core/pipeline/generate/dcls.py @@ -0,0 +1,23 @@ +import numpy as np +from modules.core.pipeline.dcls import TTSSegment + + +class SynthAudio: + def __init__(self, segment: TTSSegment) -> None: + self.seg = segment + self.data = np.empty(0) + self.sr = 24000 + self.done = False + + +class TTSBucket: + def __init__(self, segments: list[SynthAudio]) -> None: + self.segments = segments + + +class TTSBatch: + def __init__(self, segments: list[SynthAudio]) -> None: + self.segments = segments + + def is_done(self): + return all([result.done for result in self.segments]) diff --git a/modules/core/pipeline/pipeline.py b/modules/core/pipeline/pipeline.py new file mode 100644 index 0000000..1145efe --- /dev/null +++ b/modules/core/pipeline/pipeline.py @@ -0,0 +1,103 @@ +from time import sleep +from typing import Generator, Literal, Union + +from pydub import AudioSegment + +from modules.core.models.TTSModel import TTSModel +from modules.core.pipeline.dcls import TTSSegment +from modules.core.pipeline.generate.BatchSynth import BatchSynth +from modules.core.pipeline.generate.Chunker import TTSChunker +from modules.core.pipeline.processor import ( + AUDIO, + NP_AUDIO, + AudioProcessor, + PreProcessor, + TTSPipelineContext, +) +from modules.utils import audio_utils + + +class TTSPipeline: + def __init__(self, context: TTSPipelineContext): + self.modules: list[Union[AudioProcessor, PreProcessor]] = [] + self.model: TTSModel = None + self.context = context + + def add_module(self, module): + self.modules.append(module) + + def set_model(self, model): + self.model = model + + def create_synth(self): + chunker = TTSChunker(context=self.context) + segments = chunker.segments() + # 其实这个在 chunker 之前调用好点...但是有副作用所以放在后面 + segments = [self.process_pre(seg) for seg in segments] + + synth = BatchSynth( + input_segments=segments, context=self.context, model=self.model + ) + return synth + + def generate(self) -> NP_AUDIO: + synth = self.create_synth() + synth.start_generate() + synth.wait_done() + audio = synth.sr(), synth.read() + return self.process_np_audio(audio) + + def generate_stream(self) -> Generator[NP_AUDIO, None, None]: + synth = self.create_synth() + synth.start_generate() + while not synth.is_done(): + data = synth.read() + if data.size > 0: + audio = synth.sr(), data + yield self.process_np_audio(audio) + # TODO: replace with threading.Event + sleep(0.1) + data = synth.read() + if data.size > 0: + audio = synth.sr(), data + yield self.process_np_audio(audio) + + def process_np_audio(self, audio: NP_AUDIO) -> NP_AUDIO: + audio = self.process_audio(audio) + return self.ensure_audio_type(audio, "ndarray") + + def ensure_audio_type( + self, audio: AUDIO, output_type: Literal["ndarray", "segment"] + ): + if output_type == "segment": + audio = self._to_audio_segment(audio) + elif output_type == "ndarray": + audio = self._to_ndarray(audio) + else: + raise ValueError(f"Invalid output_type: 
{output_type}") + return audio + + def _to_audio_segment(self, audio: AUDIO) -> AudioSegment: + if isinstance(audio, tuple): + sr, data = audio + audio = audio_utils.ndarray_to_segment(ndarray=data, frame_rate=sr) + return audio + + def _to_ndarray(self, audio: AUDIO) -> NP_AUDIO: + if isinstance(audio, AudioSegment): + sr = audio.frame_rate + audio = audio_utils.audiosegment_to_librosawav(audio) + return sr, audio + return audio + + def process_pre(self, seg: TTSSegment): + for module in self.modules: + if isinstance(module, PreProcessor): + seg = module.process(segment=seg, context=self.context) + return seg + + def process_audio(self, audio: AUDIO): + for module in self.modules: + if isinstance(module, AudioProcessor): + audio = module.process(audio=audio, context=self.context) + return audio diff --git a/modules/core/pipeline/processor.py b/modules/core/pipeline/processor.py new file mode 100644 index 0000000..589828b --- /dev/null +++ b/modules/core/pipeline/processor.py @@ -0,0 +1,38 @@ +from typing import Union + +import numpy as np +from pydub import AudioSegment + +from modules.core.pipeline.dcls import TTSPipelineContext, TTSSegment +from modules.utils import audio_utils as audio_utils + +NP_AUDIO = tuple[int, np.ndarray] +AUDIO = Union[NP_AUDIO, AudioSegment] + + +class PreProcessor: + def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegment: + raise NotImplementedError + + +class AudioProcessor: + def process(self, audio: AUDIO, context: TTSPipelineContext) -> AUDIO: + if isinstance(audio, tuple): + return self._process_array(audio, context) + elif isinstance(audio, AudioSegment): + return self._process_segment(audio, context) + else: + raise ValueError("Unsupported audio type") + + def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO: + sr, data = audio + segment = audio_utils.ndarray_to_segment(ndarray=data, frame_rate=sr) + processed_segment = self._process_segment(segment, context) + return audio_utils.audiosegment_to_librosawav(processed_segment) + + def _process_segment( + self, audio: AudioSegment, context: TTSPipelineContext + ) -> AudioSegment: + ndarray = audio_utils.audiosegment_to_librosawav(audio) + processed_ndarray = self._process_array(ndarray, context) + return audio_utils.ndarray_to_segment(processed_ndarray) diff --git a/modules/core/pipeline/processors/Adjuster.py b/modules/core/pipeline/processors/Adjuster.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/processors/Denoiser.py b/modules/core/pipeline/processors/Denoiser.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/processors/Enhancer.py b/modules/core/pipeline/processors/Enhancer.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/processors/TN.py b/modules/core/pipeline/processors/TN.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/pipeline/processors/__init__.py b/modules/core/pipeline/processors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/speaker.py b/modules/core/speaker.py new file mode 100644 index 0000000..5938f60 --- /dev/null +++ b/modules/core/speaker.py @@ -0,0 +1,197 @@ +import os +import uuid +from typing import Optional, Union + +import numpy as np +import torch +from box import Box + +from modules.core.models import zoo +from modules.utils.SeedContext import SeedContext + + +def create_speaker_from_seed(seed): + chat_tts = zoo.ChatTTS.load_chat_tts() + with SeedContext(seed, 
True): + emb = chat_tts._sample_random_speaker() + return emb + + +class Speaker: + @staticmethod + def from_file(file_like): + speaker = torch.load(file_like, map_location=torch.device("cpu")) + speaker.fix() + return speaker + + @staticmethod + def from_tensor(tensor): + speaker = Speaker(seed_or_tensor=-2) + speaker.emb = tensor + return speaker + + @staticmethod + def from_seed(seed: int): + speaker = Speaker(seed_or_tensor=seed) + speaker.emb = create_speaker_from_seed(seed) + return speaker + + def __init__( + self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe="" + ): + self.id = uuid.uuid4() + self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor + self.name = name + self.gender = gender + self.describe = describe + self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor + + # TODO replace emb => tokens + self.tokens: list[torch.Tensor] = [] + + # TODO refer_wav, sample_wav + self.refer_wav: Optional[np.ndarray] = None + self.sample_wav: Optional[np.ndarray] = None + + # TODO 适用于某个 model 的 speaker + self.model_name: str = "" + self.model_hash: str = "" + + def to_json(self, with_emb=False): + return Box( + **{ + "id": str(self.id), + "seed": self.seed, + "name": self.name, + "gender": self.gender, + "describe": self.describe, + "emb": self.emb.tolist() if with_emb else None, + } + ) + + def fix(self): + is_update = False + if "id" not in self.__dict__: + setattr(self, "id", uuid.uuid4()) + is_update = True + if "seed" not in self.__dict__: + setattr(self, "seed", -2) + is_update = True + if "name" not in self.__dict__: + setattr(self, "name", "") + is_update = True + if "gender" not in self.__dict__: + setattr(self, "gender", "*") + is_update = True + if "describe" not in self.__dict__: + setattr(self, "describe", "") + is_update = True + + return is_update + + def __hash__(self): + return hash(str(self.id)) + + def __eq__(self, other): + if not isinstance(other, Speaker): + return False + return str(self.id) == str(other.id) + + +# 每个speaker就是一个 emb 文件 .pt +# 管理 speaker 就是管理 ./data/speaker/ 下的所有 speaker +# 可以 用 seed 创建一个 speaker +# 可以 刷新列表 读取所有 speaker +# 可以列出所有 speaker +class SpeakerManager: + def __init__(self): + self.speakers = {} + self.speaker_dir = "./data/speakers/" + self.refresh_speakers() + + def refresh_speakers(self): + self.speakers = {} + for speaker_file in os.listdir(self.speaker_dir): + if speaker_file.endswith(".pt"): + self.speakers[speaker_file] = Speaker.from_file( + self.speaker_dir + speaker_file + ) + # 检查是否有被删除的,同步到 speakers + for fname, spk in self.speakers.items(): + if not os.path.exists(self.speaker_dir + fname): + del self.speakers[fname] + + def list_speakers(self) -> list[Speaker]: + return list(self.speakers.values()) + + def create_speaker_from_seed(self, seed, name="", gender="", describe=""): + if name == "": + name = seed + filename = name + ".pt" + speaker = Speaker(seed, name=name, gender=gender, describe=describe) + speaker.emb = create_speaker_from_seed(seed) + torch.save(speaker, self.speaker_dir + filename) + self.refresh_speakers() + return speaker + + def create_speaker_from_tensor( + self, tensor, filename="", name="", gender="", describe="" + ): + if filename == "": + filename = name + speaker = Speaker( + seed_or_tensor=-2, name=name, gender=gender, describe=describe + ) + if isinstance(tensor, torch.Tensor): + speaker.emb = tensor + if isinstance(tensor, list): + speaker.emb = torch.tensor(tensor) + torch.save(speaker, self.speaker_dir + filename + ".pt") + 
self.refresh_speakers() + return speaker + + def get_speaker(self, name) -> Union[Speaker, None]: + for speaker in self.speakers.values(): + if speaker.name == name: + return speaker + return None + + def get_speaker_by_id(self, id) -> Union[Speaker, None]: + for speaker in self.speakers.values(): + if str(speaker.id) == str(id): + return speaker + return None + + def get_speaker_filename(self, id: str): + filename = None + for fname, spk in self.speakers.items(): + if str(spk.id) == str(id): + filename = fname + break + return filename + + def update_speaker(self, speaker: Speaker): + filename = None + for fname, spk in self.speakers.items(): + if str(spk.id) == str(speaker.id): + filename = fname + break + + if filename: + torch.save(speaker, self.speaker_dir + filename) + self.refresh_speakers() + return speaker + else: + raise ValueError("Speaker not found for update") + + def save_all(self): + for speaker in self.speakers.values(): + filename = self.get_speaker_filename(speaker.id) + torch.save(speaker, self.speaker_dir + filename) + # self.refresh_speakers() + + def __len__(self): + return len(self.speakers) + + +speaker_mgr = SpeakerManager() diff --git a/modules/ssml_parser/SSMLParser.py b/modules/core/ssml/SSMLParser.py similarity index 92% rename from modules/ssml_parser/SSMLParser.py rename to modules/core/ssml/SSMLParser.py index 5fc27ef..3848f0a 100644 --- a/modules/ssml_parser/SSMLParser.py +++ b/modules/core/ssml/SSMLParser.py @@ -1,5 +1,5 @@ import logging -from typing import List, Union +from typing import List, Literal, Union from box import Box from lxml import etree @@ -42,6 +42,10 @@ def __init__(self, duration_ms: Union[str, int, float]): class SSMLParser: + """ + 基础类,在其他模块中不应该手动创建 + 用法看 create_ssml_v01_parser + """ def __init__(self): self.logger = logging.getLogger(__name__) @@ -55,10 +59,11 @@ def decorator(func): return decorator - def parse(self, ssml: str) -> List[Union[SSMLSegment, SSMLBreak]]: + def parse( + self, ssml: str, root_ctx=SSMLContext() + ) -> List[Union[SSMLSegment, SSMLBreak]]: root = etree.fromstring(ssml) - root_ctx = SSMLContext() segments: List[Union[SSMLSegment, SSMLBreak]] = [] self.resolve(root, root_ctx, segments) @@ -76,7 +81,7 @@ def resolve( resolver(element, context, segments, self) -def create_ssml_parser(): +def create_ssml_v01_parser(): parser = SSMLParser() @parser.resolver("speak") @@ -172,8 +177,14 @@ def tag_prosody( return parser +def get_ssml_parser_for(version: Literal["0.1"]): + if version == "0.1": + return create_ssml_v01_parser() + raise ValueError(f"Unsupported SSML version {version}") + + if __name__ == "__main__": - parser = create_ssml_parser() + parser = create_ssml_v01_parser() ssml = """ diff --git a/modules/SynthesizeSegments.py b/modules/core/ssml/SynthesizeSSML.py similarity index 92% rename from modules/SynthesizeSegments.py rename to modules/core/ssml/SynthesizeSSML.py index c19f388..b54ba28 100644 --- a/modules/SynthesizeSegments.py +++ b/modules/core/ssml/SynthesizeSSML.py @@ -12,12 +12,12 @@ from modules import generate_audio from modules.api.utils import calc_spk_style +from modules.core.speaker import Speaker +from modules.core.ssml.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment +from modules.core.tools.SentenceSplitter import SentenceSplitter from modules.normalization import text_normalize -from modules.SentenceSplitter import SentenceSplitter -from modules.speaker import Speaker -from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment from modules.utils import rng -from 
modules.utils.audio import apply_prosody_to_audio_segment +from modules.utils.audio_utils import apply_prosody_to_audio_segment, pydub_to_np logger = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def __init__(self, *args, **kwargs): self.prefix = kwargs.get("prefix", "") -class SynthesizeSegments: +class SynthesizeSSML: def __init__(self, batch_size: int = 8, eos="", spliter_thr=100): self.batch_size = batch_size self.batch_default_spk_seed = rng.np_rng() @@ -288,7 +288,7 @@ def is_none_speak_segment(segment: SSMLSegment): return ret_segments - def synthesize_segments( + def synthesize( self, segments: List[Union[SSMLSegment, SSMLBreak]] ) -> List[AudioSegment]: segments = self.split_segments(segments) @@ -305,6 +305,19 @@ def synthesize_segments( return audio_segments + def synthesize_combine( + self, segments: List[Union[SSMLSegment, SSMLBreak]] + ) -> AudioSegment: + audio_segments = self.synthesize(segments) + return combine_audio_segments(audio_segments) + + def synthesize_combine_np( + self, segments: List[Union[SSMLSegment, SSMLBreak]] + ) -> tuple[int, np.ndarray]: + combined_audio = self.synthesize_combine(segments) + sample_rate, audio_data = pydub_to_np(combined_audio) + return sample_rate, audio_data + # 示例使用 if __name__ == "__main__": @@ -323,8 +336,8 @@ def synthesize_segments( SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()), ] - synthesizer = SynthesizeSegments(batch_size=2) - audio_segments = synthesizer.synthesize_segments(ssml_segments) + synthesizer = SynthesizeSSML(batch_size=2) + audio_segments = synthesizer.synthesize(ssml_segments) print(audio_segments) combined_audio = combine_audio_segments(audio_segments) combined_audio.export("output.wav", format="wav") diff --git a/modules/core/ssml/__init__.py b/modules/core/ssml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/normalization.py b/modules/core/tn/ChatTtsTN.py similarity index 59% rename from modules/normalization.py rename to modules/core/tn/ChatTtsTN.py index eacc596..396efc9 100644 --- a/modules/normalization.py +++ b/modules/core/tn/ChatTtsTN.py @@ -3,49 +3,73 @@ import emojiswitch import ftfy +from pywrapfst import FstOpError +from tn.english.normalizer import Normalizer as EnNormalizer -from modules import models -from modules.utils.detect_lang import guess_lang +from modules.core.models import zoo +from modules.core.tn.TNPipeline import GuessLang, TNPipeline +from modules.repos_static.zh_normalization.text_normlization import TextNormalizer from modules.utils.HomophonesReplacer import HomophonesReplacer from modules.utils.html import remove_html_tags as _remove_html_tags from modules.utils.markdown import markdown_to_text -from modules.utils.zh_normalization.text_normlization import TextNormalizer -# 是否关闭 unk token 检查 -# NOTE: 单测的时候用于跳过模型加载 DISABLE_UNK_TOKEN_CHECK = False - -post_normalize_pipeline = [] -pre_normalize_pipeline = [] - - -def post_normalize(): - def decorator(func): - post_normalize_pipeline.append(func) - return func - - return decorator - - -def pre_normalize(): - def decorator(func): - pre_normalize_pipeline.append(func) - return func - - return decorator - - -def apply_pre_normalize(text): - for func in pre_normalize_pipeline: - text = func(text) - return text - - -def apply_post_normalize(text): - for func in post_normalize_pipeline: - text = func(text) - return text +ChatTtsTN = TNPipeline() +ChatTtsTN.freeze_strs = [ + "[Sasr]", + "[Pasr]", + "[Easr]", + "[Stts]", + "[Ptts]", + "[Etts]", + "[Sbreak]", + "[Pbreak]", + "[Ebreak]", + "[uv_break]", + 
"[v_break]", + "[lbreak]", + "[llbreak]", + "[undefine]", + "[laugh]", + "[spk_emb]", + "[empty_spk]", + "[music]", + "[pure]", + "[break_0]", + "[break_1]", + "[break_2]", + "[break_3]", + "[break_4]", + "[break_5]", + "[break_6]", + "[break_7]", + "[laugh_0]", + "[laugh_1]", + "[laugh_2]", + "[oral_0]", + "[oral_1]", + "[oral_2]", + "[oral_3]", + "[oral_4]", + "[oral_5]", + "[oral_6]", + "[oral_7]", + "[oral_8]", + "[oral_9]", + "[speed_0]", + "[speed_1]", + "[speed_2]", + "[speed_3]", + "[speed_4]", + "[speed_5]", + "[speed_6]", + "[speed_7]", + "[speed_8]", + "[speed_9]", +] + +# ------- UTILS --------- def is_markdown(text): @@ -107,109 +131,37 @@ def is_markdown(text): "·": " ", } -character_to_word = { - " & ": " and ", -} - -## ---------- post normalize ---------- - - -@post_normalize() -def apply_character_to_word(text): - for k, v in character_to_word.items(): - text = text.replace(k, v) - return text - - -@post_normalize() -def apply_character_map(text): - translation_table = str.maketrans(character_map) - return text.translate(translation_table) - - -@post_normalize() -def apply_emoji_map(text): - lang = guess_lang(text) - return emojiswitch.demojize(text, delimiters=("", ""), lang=lang) - +# ----------------------- -@post_normalize() -def insert_spaces_between_uppercase(s): - # 使用正则表达式在每个相邻的大写字母之间插入空格 - return re.sub( - r"(?<=[A-Z])(?=[A-Z])|(?<=[a-z])(?=[A-Z])|(?<=[\u4e00-\u9fa5])(?=[A-Z])|(?<=[A-Z])(?=[\u4e00-\u9fa5])", - " ", - s, - ) - - -@post_normalize() -def replace_unk_tokens(text): - """ - 把不在字典里的字符替换为 " , " - """ - if DISABLE_UNK_TOKEN_CHECK: - return text - chat_tts = models.load_chat_tts() - if "tokenizer" not in chat_tts.pretrain_models: - # 这个地方只有在 huggingface spaces 中才会触发 - # 因为 hugggingface 自动处理模型卸载加载,所以如果拿不到就算了... - return text - tokenizer = chat_tts.pretrain_models["tokenizer"] - vocab = tokenizer.get_vocab() - vocab_set = set(vocab.keys()) - # 添加所有英语字符 - vocab_set.update(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")) - vocab_set.update(set(" \n\r\t")) - replaced_chars = [char if char in vocab_set else " , " for char in text] - output_text = "".join(replaced_chars) - return output_text - -homo_replacer = HomophonesReplacer( - map_file_path="./modules/ChatTTS/ChatTTS/res/homophones_map.json" -) - - -@post_normalize() -def replace_homophones(text): - lang = guess_lang(text) - if lang == "zh": - text = homo_replacer.replace(text) - return text - - -## ---------- pre normalize ---------- - - -@pre_normalize() -def html_unescape(text): +@ChatTtsTN.block() +def html_unescape(text: str, guess_lang: GuessLang): text = html.unescape(text) text = html.unescape(text) return text -@pre_normalize() -def fix_text(text): +@ChatTtsTN.block() +def fix_text(text: str, guess_lang: GuessLang): return ftfy.fix_text(text=text) -@pre_normalize() -def apply_markdown_to_text(text): +@ChatTtsTN.block() +def apply_markdown_to_text(text: str, guess_lang: GuessLang): if is_markdown(text): text = markdown_to_text(text) return text -@pre_normalize() -def remove_html_tags(text): +@ChatTtsTN.block() +def remove_html_tags(text: str, guess_lang: GuessLang): return _remove_html_tags(text) # 将 "xxx" => \nxxx\n # 将 'xxx' => \nxxx\n -@pre_normalize() -def replace_quotes(text): +@ChatTtsTN.block() +def replace_quotes(text: str, guess_lang: GuessLang): repl = r"\n\1\n" patterns = [ ['"', '"'], @@ -222,83 +174,102 @@ def replace_quotes(text): return text -def ensure_suffix(a: str, b: str, c: str): - a = a.strip() - if not a.endswith(b): - a += c - return a +# ---- main normalize ---- 
-email_domain_map = { - "outlook.com": "Out look", - "hotmail.com": "Hot mail", - "yahoo.com": "雅虎", -} +@ChatTtsTN.block(name="tx_zh", enabled=True) +def tx_normalize(text: str, guss_lang: GuessLang): + if guss_lang.zh_or_en != "zh": + return text + # NOTE: 这个是魔改过的 TextNormalizer 来自 PaddlePaddle + tx = TextNormalizer() + # NOTE: 为什么要分行?因为我们需要保留 "\n" 作为 chunker 的分割信号 + lines = [line for line in text.split("\n") if line.strip() != ""] + texts: list[str] = [] + for line in lines: + ts = tx.normalize(line) + texts.append("".join(ts)) + return "\n".join(texts) + + +@ChatTtsTN.block(name="wetext_en", enabled=True) +def wetext_normalize(text: str, guss_lang: GuessLang): + if guss_lang.zh_or_en == "en": + en_tn_model = EnNormalizer(overwrite_cache=False) + try: + return en_tn_model.normalize(text) + except FstOpError: + # NOTE: 不太理解为什么 tn 都能出错... + pass + return text + +# ---- end main normalize ---- -# 找到所有 email 并将 name 分割为单个字母,@替换为 at ,. 替换为 dot,常见域名替换为单词 -# -# 例如: -# zhzluke96@outlook.com => z h z l u k e 9 6 at out look dot com -def email_detect(text): - email_pattern = re.compile(r"([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})") - def replace(match): - email = match.group(1) - name, domain = email.split("@") - name = " ".join(name) - if domain in email_domain_map: - domain = email_domain_map[domain] - domain = domain.replace(".", " dot ") - return f"{name} at {domain}" +@ChatTtsTN.block() +def apply_character_map(text: str, guess_lang: GuessLang): + translation_table = str.maketrans(character_map) + return text.translate(translation_table) - return email_pattern.sub(replace, text) +@ChatTtsTN.block() +def apply_emoji_map(text: str, guess_lang: GuessLang): + return emojiswitch.demojize(text, delimiters=("", ""), lang=guess_lang.zh_or_en) -def sentence_normalize(sentence_text: str): - # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization - tx = TextNormalizer() - # 匹配 \[.+?\] 的部分 - pattern = re.compile(r"(\[.+?\])|([^[]+)") +@ChatTtsTN.block() +def insert_spaces_between_uppercase(text: str, guess_lang: GuessLang): + # 使用正则表达式在每个相邻的大写字母之间插入空格 + return re.sub( + r"(?<=[A-Z])(?=[A-Z])|(?<=[a-z])(?=[A-Z])|(?<=[\u4e00-\u9fa5])(?=[A-Z])|(?<=[A-Z])(?=[\u4e00-\u9fa5])", + " ", + text, + ) - def normalize_part(part): - sentences = tx.normalize(part) if guess_lang(part) == "zh" else [part] - dest_text = "" - for sentence in sentences: - sentence = apply_post_normalize(sentence) - dest_text += sentence - return dest_text - def replace(match): - if match.group(1): - return f" {match.group(1)} " - else: - return normalize_part(match.group(2)) +@ChatTtsTN.block() +def replace_unk_tokens(text: str, guess_lang: GuessLang): + """ + 把不在字典里的字符替换为 " , " - result = pattern.sub(replace, sentence_text) + FIXME: 总感觉不太好...但是没有遇到问题的话暂时留着... + """ + if DISABLE_UNK_TOKEN_CHECK: + return text + chat_tts = zoo.ChatTTS.load_chat_tts() + if "tokenizer" not in chat_tts.pretrain_models: + # 这个地方只有在 huggingface spaces 中才会触发 + # 因为 hugggingface 自动处理模型卸载加载,所以如果拿不到就算了... + return text + tokenizer = zoo.ChatTTS.get_tokenizer() + vocab = tokenizer.get_vocab() + vocab_set = set(vocab.keys()) + # 添加所有英语字符 + vocab_set.update(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")) + vocab_set.update(set(" \n\r\t")) + replaced_chars = [char if char in vocab_set else " , " for char in text] + output_text = "".join(replaced_chars) + return output_text - # NOTE: 加了会有杂音... 
- # if is_end: - # 加这个是为了防止吞字 - # result = ensure_suffix(result, "[uv_break]", "。。。[uv_break]。。。") - return result +homo_replacer = HomophonesReplacer( + map_file_path="./modules/ChatTTS/ChatTTS/res/homophones_map.json" +) -def text_normalize(text, is_end=False): - text = apply_pre_normalize(text) - lines = text.split("\n") - lines = [line.strip() for line in lines] - lines = [line for line in lines if line] - lines = [sentence_normalize(line) for line in lines] - content = "\n".join(lines) - return content +@ChatTtsTN.block() +def replace_homophones(text: str, guess_lang: GuessLang): + if guess_lang.zh_or_en == "zh": + text = homo_replacer.replace(text) + return text if __name__ == "__main__": from modules.devices import devices + DISABLE_UNK_TOKEN_CHECK = True + devices.reset_device() test_cases = [ "ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.", @@ -343,4 +314,4 @@ def text_normalize(text, is_end=False): ] for i, test_case in enumerate(test_cases): - print(f"case {i}:\n", {"x": text_normalize(test_case, is_end=True)}) + print(f"case {i}:\n", {"x": ChatTtsTN.normalize(test_case)}) diff --git a/modules/core/tn/TNPipeline.py b/modules/core/tn/TNPipeline.py new file mode 100644 index 0000000..8ce58b7 --- /dev/null +++ b/modules/core/tn/TNPipeline.py @@ -0,0 +1,141 @@ +from typing import Callable, Dict, Literal, Optional + +from langdetect import LangDetectException, detect_langs +from pydantic import BaseModel + +from modules.core.handler.datacls.tn_model import TNConfig +from modules.utils.detect_lang import guess_lang + + +class GuessLang(BaseModel): + zh_or_en: Literal["zh", "en"] + detected: Dict[str, float] + + +class TNBlock: + + def __init__(self, name: str): + self.name = name + self.enabled = True + + def process(self, text: str, guess_lang: GuessLang): + raise NotImplementedError + + +TNBlockFnType = Callable[[str, GuessLang], str] + + +class TNBlockFn(TNBlock): + def __init__(self, name: str, fn: TNBlockFnType): + super().__init__(name) + self.fn = fn + + def process(self, text: str, guess_lang: GuessLang): + return self.fn(text, guess_lang) + + +class TNText(BaseModel): + text: str + type: Literal["normal", "freeze"] + + +class TNPipeline: + """文本归一化管道类""" + + SEP_CHAR = "\n" + + def __init__(self): + self.blocks: list[TNBlock] = [] + self.freeze_strs: list[str] = [] + + def block(self, name: str = None, enabled: bool = True): + block = TNBlockFn(name=name, fn=None) + block.enabled = enabled + self.blocks.append(block) + + def decorator(fn: TNBlockFnType): + block.fn = fn + if not block.name: + block.name = fn.__name__ + return fn + + return decorator + + def split_string_with_freeze( + self, text: str, freeze_strs: list[str] + ) -> list[TNText]: + if not freeze_strs: + return [TNText(text=text, type="normal")] + + result: list[TNText] = [] + buffer = "" + + for char in text: + buffer += char + + for freeze_str in freeze_strs: + if buffer.endswith(freeze_str): + result.append( + TNText(text=buffer[: -len(freeze_str)], type="normal") + ) + result.append(TNText(text=freeze_str, type="freeze")) + buffer = "" + break + + if buffer: + result.append(TNText(text=buffer, type="normal")) + + return result + + def normalize(self, text: str, config: Optional[TNConfig] = None) -> str: + texts: list[TNText] = self.split_string_with_freeze(text, self.freeze_strs) + + result = "" + + for tn_text in texts: + if tn_text.type == "normal": + result += self._normalize(tn_text.text, config) + else: + result += tn_text.text + result 
+= self.SEP_CHAR + + return result.strip() + + def guess_langs(self, text: str): + zh_or_en = guess_lang(text) + try: + detected_langs = detect_langs(text) + detected = {lang.lang: lang.prob for lang in detected_langs} + except LangDetectException: + detected = { + "zh": 1.0 if zh_or_en == "zh" else 0.0, + "en": 1.0 if zh_or_en == "en" else 0.0, + } + guess = GuessLang(zh_or_en=zh_or_en, detected=detected) + return guess + + def _normalize(self, text: str, config: Optional[TNConfig] = TNConfig()): + if config is None: + config = TNConfig() + enabled_block = config.enabled if config.enabled else [] + disabled_block = config.disabled if config.disabled else [] + + guess = self.guess_langs(text) + + for block in self.blocks: + enabled = block.enabled + + if block.name in enabled_block: + enabled = True + if block.name in disabled_block: + enabled = False + + if not enabled: + continue + # print(text) + # print("---", block.name) + text = block.process(text=text, guess_lang=guess) + # print("---") + # print(text) + + return text diff --git a/modules/core/tn/__init__.py b/modules/core/tn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/SentenceSplitter.py b/modules/core/tools/SentenceSplitter.py similarity index 92% rename from modules/SentenceSplitter.py rename to modules/core/tools/SentenceSplitter.py index 847d717..ec00c64 100644 --- a/modules/SentenceSplitter.py +++ b/modules/core/tools/SentenceSplitter.py @@ -2,13 +2,14 @@ import zhon -from modules.models import get_tokenizer +from modules.core.models import zoo from modules.utils.detect_lang import guess_lang # 解析文本 并根据停止符号分割成句子 # 可以设置最大阈值,即如果分割片段小于这个阈值会与下一段合并 class SentenceSplitter: + # 分隔符 用于连接句子 sentence1 + SEP_TOKEN + sentence2 SEP_TOKEN = " " def __init__(self, threshold=100): @@ -17,9 +18,12 @@ def __init__(self, threshold=100): ), "Threshold must be greater than 0." self.sentence_threshold = threshold - self.tokenizer = get_tokenizer() + self.tokenizer = zoo.ChatTTS.get_tokenizer() - def count_tokens(self, text: str): + def len(self, text: str): + """ + Get the length of tokenized text. 
+ """ return len(self.tokenizer.tokenize(text)) def parse(self, text: str): @@ -37,7 +41,7 @@ def merge_text_by_threshold(self, setences: list[str]): merged_sentences: list[str] = [] temp_sentence = "" for sentence in setences: - if len(temp_sentence) + len(sentence) < self.sentence_threshold: + if self.len(temp_sentence) + self.len(sentence) < self.sentence_threshold: temp_sentence += SentenceSplitter.SEP_TOKEN + sentence else: merged_sentences.append(temp_sentence) diff --git a/modules/core/tools/__init__.py b/modules/core/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/core/tools/misc.py b/modules/core/tools/misc.py new file mode 100644 index 0000000..b09703e --- /dev/null +++ b/modules/core/tools/misc.py @@ -0,0 +1,6 @@ +def to_number(value, t, default=0): + try: + number = t(value) + return number + except (ValueError, TypeError) as e: + return default diff --git a/modules/devices/devices.py b/modules/devices/devices.py index 3b201ee..8fc3e04 100644 --- a/modules/devices/devices.py +++ b/modules/devices/devices.py @@ -75,7 +75,10 @@ def get_target_device_id_or_memory_available_gpu(): def get_optimal_device_name(): - if config.runtime_env_vars.use_cpu == "all": + if config.runtime_env_vars.use_cpu is None: + config.runtime_env_vars.use_cpu = [] + + if "all" in config.runtime_env_vars.use_cpu: return "cpu" if torch.cuda.is_available(): @@ -92,6 +95,9 @@ def get_optimal_device(): def get_device_for(task): + if config.runtime_env_vars.use_cpu is None: + config.runtime_env_vars.use_cpu = [] + if ( task in config.runtime_env_vars.use_cpu or "all" in config.runtime_env_vars.use_cpu diff --git a/modules/finetune/train_speaker.py b/modules/finetune/train_speaker.py index be5077e..9df91c3 100644 --- a/modules/finetune/train_speaker.py +++ b/modules/finetune/train_speaker.py @@ -209,9 +209,9 @@ def train_speaker_embeddings( import numpy as np from modules import config + from modules.core.speaker import Speaker from modules.devices import devices from modules.models import load_chat_tts - from modules.speaker import Speaker config.runtime_env_vars.no_half = True config.runtime_env_vars.use_cpu = [] diff --git a/modules/gradio_dcls_fix.py b/modules/fixs/gradio_dcls_fix.py similarity index 100% rename from modules/gradio_dcls_fix.py rename to modules/fixs/gradio_dcls_fix.py diff --git a/modules/generate_audio.py b/modules/generate_audio.py deleted file mode 100644 index ea3404e..0000000 --- a/modules/generate_audio.py +++ /dev/null @@ -1,227 +0,0 @@ -import gc -import logging -from typing import Generator, Union - -import numpy as np -import torch - -from modules import config, models -from modules.ChatTTS import ChatTTS -from modules.ChatTTSInfer import ChatTTSInfer -from modules.devices import devices -from modules.speaker import Speaker -from modules.utils.cache import conditional_cache -from modules.utils.SeedContext import SeedContext - -logger = logging.getLogger(__name__) - -SAMPLE_RATE = 24000 - - -def generate_audio( - text: str, - temperature: float = 0.3, - top_P: float = 0.7, - top_K: float = 20, - spk: Union[int, Speaker] = -1, - infer_seed: int = -1, - use_decoder: bool = True, - prompt1: str = "", - prompt2: str = "", - prefix: str = "", -): - (sample_rate, wav) = generate_audio_batch( - [text], - temperature=temperature, - top_P=top_P, - top_K=top_K, - spk=spk, - infer_seed=infer_seed, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, - )[0] - - return (sample_rate, wav) - - -def parse_infer_spk_emb( - spk: 
Union[int, Speaker] = -1, -): - if isinstance(spk, int): - logger.debug(("[spk_from_seed]", spk)) - return Speaker.from_seed(spk).emb - elif isinstance(spk, Speaker): - logger.debug(("[spk_from_file]", spk.name)) - if not isinstance(spk.emb, torch.Tensor): - raise ValueError("spk.pt is broken, please retrain the model.") - return spk.emb - else: - logger.warn( - f"spk must be int or Speaker, but: <{type(spk)}> {spk}, wiil set to default voice" - ) - return Speaker.from_seed(2).emb - - -@torch.inference_mode() -def generate_audio_batch( - texts: list[str], - temperature: float = 0.3, - top_P: float = 0.7, - top_K: float = 20, - spk: Union[int, Speaker] = -1, - infer_seed: int = -1, - use_decoder: bool = True, - prompt1: str = "", - prompt2: str = "", - prefix: str = "", -): - chat_tts = models.load_chat_tts() - spk_emb = parse_infer_spk_emb( - spk=spk, - ) - logger.debug( - ( - "[generate_audio_batch]", - { - "texts": texts, - "temperature": temperature, - "top_P": top_P, - "top_K": top_K, - "spk": spk, - "infer_seed": infer_seed, - "use_decoder": use_decoder, - "prompt1": prompt1, - "prompt2": prompt2, - "prefix": prefix, - }, - ) - ) - - with SeedContext(infer_seed, True): - infer = ChatTTSInfer(instance=chat_tts) - wavs = infer.generate_audio( - text=texts, - spk_emb=spk_emb, - temperature=temperature, - top_K=top_K, - top_P=top_P, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, - ) - - if config.auto_gc: - devices.torch_gc() - gc.collect() - - return [(SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)) for wav in wavs] - - -# TODO: generate_audio_stream 也应该支持 lru cache -@torch.inference_mode() -def generate_audio_stream( - text: str, - temperature: float = 0.3, - top_P: float = 0.7, - top_K: float = 20, - spk: Union[int, Speaker] = -1, - infer_seed: int = -1, - use_decoder: bool = True, - prompt1: str = "", - prompt2: str = "", - prefix: str = "", -) -> Generator[tuple[int, np.ndarray], None, None]: - chat_tts = models.load_chat_tts() - texts = [text] - spk_emb = parse_infer_spk_emb( - spk=spk, - ) - logger.debug( - ( - "[generate_audio_stream]", - { - "text": text, - "temperature": temperature, - "top_P": top_P, - "top_K": top_K, - "spk": spk, - "infer_seed": infer_seed, - "use_decoder": use_decoder, - "prompt1": prompt1, - "prompt2": prompt2, - "prefix": prefix, - }, - ) - ) - - with SeedContext(infer_seed, True): - infer = ChatTTSInfer(instance=chat_tts) - wavs_gen = infer.generate_audio_stream( - text=texts, - spk_emb=spk_emb, - temperature=temperature, - top_K=top_K, - top_P=top_P, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, - ) - - for wav in wavs_gen: - yield [SAMPLE_RATE, np.array(wav).flatten().astype(np.float32)] - - if config.auto_gc: - devices.torch_gc() - gc.collect() - - return - - -lru_cache_enabled = False - - -def setup_lru_cache(): - global generate_audio_batch - global lru_cache_enabled - - if lru_cache_enabled: - return - lru_cache_enabled = True - - def should_cache(*args, **kwargs): - spk_seed = kwargs.get("spk", -1) - infer_seed = kwargs.get("infer_seed", -1) - return spk_seed != -1 and infer_seed != -1 - - lru_size = config.runtime_env_vars.lru_size - if isinstance(lru_size, int): - generate_audio_batch = conditional_cache(lru_size, should_cache)( - generate_audio_batch - ) - logger.info(f"LRU cache enabled with size {lru_size}") - else: - logger.debug(f"LRU cache failed to enable, invalid size {lru_size}") - - -if __name__ == "__main__": - import soundfile as sf - - # 测试batch生成 - 
inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"] - outputs = generate_audio_batch(inputs, spk=5, infer_seed=42) - - for i, (sample_rate, wav) in enumerate(outputs): - print(i, sample_rate, wav.shape) - - sf.write(f"batch_{i}.wav", wav, sample_rate, format="wav") - - # 单独生成 - for i, text in enumerate(inputs): - sample_rate, wav = generate_audio(text, spk=5, infer_seed=42) - print(i, sample_rate, wav.shape) - - sf.write(f"one_{i}.wav", wav, sample_rate, format="wav") diff --git a/modules/models_setup.py b/modules/models_setup.py index 8feaa14..60f5225 100644 --- a/modules/models_setup.py +++ b/modules/models_setup.py @@ -1,10 +1,9 @@ import argparse import logging -from modules import generate_audio +from modules.core.models import zoo from modules.devices import devices from modules.Enhancer.ResembleEnhance import load_enhancer -from modules.models import load_chat_tts from modules.utils import env @@ -37,6 +36,7 @@ def setup_model_args(parser: argparse.ArgumentParser): type=str.lower, choices=["all", "chattts", "enhancer", "trainer"], ) + # TODO: tts_pipeline 引入之后还不支持从这里配置 parser.add_argument( "--lru_size", type=int, @@ -66,13 +66,16 @@ def process_model_args(args: argparse.Namespace): debug_generate = env.get_and_update_env(args, "debug_generate", False, bool) preload_models = env.get_and_update_env(args, "preload_models", False, bool) - generate_audio.setup_lru_cache() + # TODO: 需要等 zoo 模块实现 + # generate_audio.setup_lru_cache() devices.reset_device() devices.first_time_calculation() - if debug_generate: - generate_audio.logger.setLevel(logging.DEBUG) + zoo.zoo_config.debug_generate = debug_generate if preload_models: - load_chat_tts() + """ + TODO: 需要增强 zoo + """ + zoo.ChatTTS.load_chat_tts() load_enhancer() diff --git a/modules/refiner.py b/modules/refiner.py index 679498a..5789ee2 100644 --- a/modules/refiner.py +++ b/modules/refiner.py @@ -1,11 +1,10 @@ from typing import Generator -import numpy as np import torch -from modules import config, models -from modules.ChatTTSInfer import ChatTTSInfer -from modules.SentenceSplitter import SentenceSplitter +from modules.core.models import zoo +from modules.core.models.zoo.ChatTTSInfer import ChatTTSInfer +from modules.core.tools.SentenceSplitter import SentenceSplitter from modules.utils.SeedContext import SeedContext @@ -21,7 +20,7 @@ def refine_text( max_new_token=384, spliter_threshold=300, ) -> str: - chat_tts = models.load_chat_tts() + chat_tts = zoo.ChatTTS.load_chat_tts() spliter = SentenceSplitter(spliter_threshold) sentences = spliter.parse(text) diff --git a/modules/utils/zh_normalization/README.md b/modules/repos_static/zh_normalization/README.md similarity index 100% rename from modules/utils/zh_normalization/README.md rename to modules/repos_static/zh_normalization/README.md diff --git a/modules/utils/zh_normalization/__init__.py b/modules/repos_static/zh_normalization/__init__.py similarity index 100% rename from modules/utils/zh_normalization/__init__.py rename to modules/repos_static/zh_normalization/__init__.py diff --git a/modules/utils/zh_normalization/char_convert.py b/modules/repos_static/zh_normalization/char_convert.py similarity index 100% rename from modules/utils/zh_normalization/char_convert.py rename to modules/repos_static/zh_normalization/char_convert.py diff --git a/modules/utils/zh_normalization/chronology.py b/modules/repos_static/zh_normalization/chronology.py similarity index 100% rename from modules/utils/zh_normalization/chronology.py rename to modules/repos_static/zh_normalization/chronology.py 
diff --git a/modules/utils/zh_normalization/constants.py b/modules/repos_static/zh_normalization/constants.py similarity index 100% rename from modules/utils/zh_normalization/constants.py rename to modules/repos_static/zh_normalization/constants.py diff --git a/modules/utils/zh_normalization/num.py b/modules/repos_static/zh_normalization/num.py similarity index 100% rename from modules/utils/zh_normalization/num.py rename to modules/repos_static/zh_normalization/num.py diff --git a/modules/utils/zh_normalization/phonecode.py b/modules/repos_static/zh_normalization/phonecode.py similarity index 100% rename from modules/utils/zh_normalization/phonecode.py rename to modules/repos_static/zh_normalization/phonecode.py diff --git a/modules/utils/zh_normalization/quantifier.py b/modules/repos_static/zh_normalization/quantifier.py similarity index 100% rename from modules/utils/zh_normalization/quantifier.py rename to modules/repos_static/zh_normalization/quantifier.py diff --git a/modules/utils/zh_normalization/text_normlization.py b/modules/repos_static/zh_normalization/text_normlization.py similarity index 100% rename from modules/utils/zh_normalization/text_normlization.py rename to modules/repos_static/zh_normalization/text_normlization.py diff --git a/modules/speaker.py b/modules/speaker.py index 1556d0a..e69fa9a 100644 --- a/modules/speaker.py +++ b/modules/speaker.py @@ -1,188 +1,3 @@ -import os -import uuid -from typing import Union - -import torch -from box import Box - -from modules import models -from modules.utils.SeedContext import SeedContext - - -def create_speaker_from_seed(seed): - chat_tts = models.load_chat_tts() - with SeedContext(seed, True): - emb = chat_tts._sample_random_speaker() - return emb - - -class Speaker: - @staticmethod - def from_file(file_like): - speaker = torch.load(file_like, map_location=torch.device("cpu")) - speaker.fix() - return speaker - - @staticmethod - def from_tensor(tensor): - speaker = Speaker(seed_or_tensor=-2) - speaker.emb = tensor - return speaker - - @staticmethod - def from_seed(seed: int): - speaker = Speaker(seed_or_tensor=seed) - speaker.emb = create_speaker_from_seed(seed) - return speaker - - def __init__( - self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe="" - ): - self.id = uuid.uuid4() - self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor - self.name = name - self.gender = gender - self.describe = describe - self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor - - # TODO replace emb => tokens - self.tokens = [] - - def to_json(self, with_emb=False): - return Box( - **{ - "id": str(self.id), - "seed": self.seed, - "name": self.name, - "gender": self.gender, - "describe": self.describe, - "emb": self.emb.tolist() if with_emb else None, - } - ) - - def fix(self): - is_update = False - if "id" not in self.__dict__: - setattr(self, "id", uuid.uuid4()) - is_update = True - if "seed" not in self.__dict__: - setattr(self, "seed", -2) - is_update = True - if "name" not in self.__dict__: - setattr(self, "name", "") - is_update = True - if "gender" not in self.__dict__: - setattr(self, "gender", "*") - is_update = True - if "describe" not in self.__dict__: - setattr(self, "describe", "") - is_update = True - - return is_update - - def __hash__(self): - return hash(str(self.id)) - - def __eq__(self, other): - if not isinstance(other, Speaker): - return False - return str(self.id) == str(other.id) - - -# 每个speaker就是一个 emb 文件 .pt -# 管理 speaker 就是管理 ./data/speaker/ 下的所有 
speaker -# 可以 用 seed 创建一个 speaker -# 可以 刷新列表 读取所有 speaker -# 可以列出所有 speaker -class SpeakerManager: - def __init__(self): - self.speakers = {} - self.speaker_dir = "./data/speakers/" - self.refresh_speakers() - - def refresh_speakers(self): - self.speakers = {} - for speaker_file in os.listdir(self.speaker_dir): - if speaker_file.endswith(".pt"): - self.speakers[speaker_file] = Speaker.from_file( - self.speaker_dir + speaker_file - ) - # 检查是否有被删除的,同步到 speakers - for fname, spk in self.speakers.items(): - if not os.path.exists(self.speaker_dir + fname): - del self.speakers[fname] - - def list_speakers(self) -> list[Speaker]: - return list(self.speakers.values()) - - def create_speaker_from_seed(self, seed, name="", gender="", describe=""): - if name == "": - name = seed - filename = name + ".pt" - speaker = Speaker(seed, name=name, gender=gender, describe=describe) - speaker.emb = create_speaker_from_seed(seed) - torch.save(speaker, self.speaker_dir + filename) - self.refresh_speakers() - return speaker - - def create_speaker_from_tensor( - self, tensor, filename="", name="", gender="", describe="" - ): - if filename == "": - filename = name - speaker = Speaker( - seed_or_tensor=-2, name=name, gender=gender, describe=describe - ) - if isinstance(tensor, torch.Tensor): - speaker.emb = tensor - if isinstance(tensor, list): - speaker.emb = torch.tensor(tensor) - torch.save(speaker, self.speaker_dir + filename + ".pt") - self.refresh_speakers() - return speaker - - def get_speaker(self, name) -> Union[Speaker, None]: - for speaker in self.speakers.values(): - if speaker.name == name: - return speaker - return None - - def get_speaker_by_id(self, id) -> Union[Speaker, None]: - for speaker in self.speakers.values(): - if str(speaker.id) == str(id): - return speaker - return None - - def get_speaker_filename(self, id: str): - filename = None - for fname, spk in self.speakers.items(): - if str(spk.id) == str(id): - filename = fname - break - return filename - - def update_speaker(self, speaker: Speaker): - filename = None - for fname, spk in self.speakers.items(): - if str(spk.id) == str(speaker.id): - filename = fname - break - - if filename: - torch.save(speaker, self.speaker_dir + filename) - self.refresh_speakers() - return speaker - else: - raise ValueError("Speaker not found for update") - - def save_all(self): - for speaker in self.speakers.values(): - filename = self.get_speaker_filename(speaker.id) - torch.save(speaker, self.speaker_dir + filename) - # self.refresh_speakers() - - def __len__(self): - return len(self.speakers) - - -speaker_mgr = SpeakerManager() +# 由于 speaker 是保留了对象状态的... 
所以需要保留这个模块 +# FIXME: 重构 speaker +from modules.core.speaker import Speaker diff --git a/modules/ssml.py b/modules/ssml.py deleted file mode 100644 index 578fd4a..0000000 --- a/modules/ssml.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging -import random -from typing import Any, Dict, List - -import numpy as np -from lxml import etree - -from modules.data import styles_mgr -from modules.speaker import speaker_mgr - -logger = logging.getLogger(__name__) - - -def expand_spk(attrs: dict): - input_spk = attrs.get("spk", "") - if isinstance(input_spk, int): - return - if isinstance(input_spk, str) and input_spk.isdigit(): - attrs.update({"spk": int(input_spk)}) - return - try: - speaker = speaker_mgr.get_speaker(input_spk) - attrs.update({"spk": speaker}) - except Exception as e: - logger.error(f"apply style failed, {e}") - - -def expand_style(attrs: dict): - if attrs.get("style", "") != "": - try: - params = styles_mgr.find_params_by_name(str(attrs["style"])) - attrs.update(params) - except Exception as e: - logger.error(f"apply style failed, {e}") - - -def merge_prompt(attrs: dict, elem): - - def attr_num(attrs: Dict[str, Any], k: str, min_value: int, max_value: int): - val = elem.get(k, attrs.get(k, "")) - if val == "": - return - if val == "max": - val = max_value - if val == "min": - val = min_value - val = np.clip(int(val), min_value, max_value) - if "prefix" not in attrs or attrs["prefix"] == None: - attrs["prefix"] = "" - attrs["prefix"] += " " + f"[{k}_{val}]" - - attr_num(attrs, "oral", 0, 9) - attr_num(attrs, "speed", 0, 9) - attr_num(attrs, "laugh", 0, 2) - attr_num(attrs, "break", 0, 7) - - -def apply_random_seed(attrs: dict): - seed = attrs.get("seed", "") - if seed == "random" or seed == "rand": - seed = random.randint(0, 2**32 - 1) - attrs["seed"] = seed - logger.info(f"random seed: {seed}") diff --git a/modules/synthesize_audio.py b/modules/synthesize_audio.py deleted file mode 100644 index 0e49f90..0000000 --- a/modules/synthesize_audio.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Union - -from modules.SentenceSplitter import SentenceSplitter -from modules.speaker import Speaker -from modules.ssml_parser.SSMLParser import SSMLSegment -from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments -from modules.utils import audio - - -def synthesize_audio( - text: str, - temperature: float = 0.3, - top_P: float = 0.7, - top_K: float = 20, - spk: Union[int, Speaker] = -1, - infer_seed: int = -1, - use_decoder: bool = True, - prompt1: str = "", - prompt2: str = "", - prefix: str = "", - batch_size: int = 1, - spliter_threshold: int = 100, - end_of_sentence="", -): - spliter = SentenceSplitter(spliter_threshold) - sentences = spliter.parse(text) - - text_segments = [ - SSMLSegment( - text=s, - params={ - "temperature": temperature, - "top_P": top_P, - "top_K": top_K, - "spk": spk, - "infer_seed": infer_seed, - "use_decoder": use_decoder, - "prompt1": prompt1, - "prompt2": prompt2, - "prefix": prefix, - }, - ) - for s in sentences - ] - synthesizer = SynthesizeSegments( - batch_size=batch_size, eos=end_of_sentence, spliter_thr=spliter_threshold - ) - audio_segments = synthesizer.synthesize_segments(text_segments) - - combined_audio = combine_audio_segments(audio_segments) - - return audio.pydub_to_np(combined_audio) diff --git a/modules/synthesize_stream.py b/modules/synthesize_stream.py deleted file mode 100644 index f14716b..0000000 --- a/modules/synthesize_stream.py +++ /dev/null @@ -1,120 +0,0 @@ -import io -from typing import Generator, Optional, 
Union - -import numpy as np -import torch -from cachetools import LRUCache -from cachetools import keys as cache_keys - -from modules import generate_audio as generate -from modules.SentenceSplitter import SentenceSplitter -from modules.speaker import Speaker - - -def handle_chunks( - wav_gen: np.ndarray, - wav_gen_prev: Optional[np.ndarray], - wav_overlap: Optional[np.ndarray], - overlap_len: int, -): - """Handle chunk formatting in streaming mode""" - wav_chunk = wav_gen[:-overlap_len] - if wav_gen_prev is not None: - wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] - if wav_overlap is not None: - # cross fade the overlap section - if overlap_len > len(wav_chunk): - # wav_chunk is smaller than overlap_len, pass on last wav_gen - if wav_gen_prev is not None: - wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] - else: - # not expecting will hit here as problem happens on last chunk - wav_chunk = wav_gen[-overlap_len:] - return wav_chunk, wav_gen, None - else: - crossfade_wav = wav_chunk[:overlap_len] - crossfade_wav = crossfade_wav * np.linspace(0.0, 1.0, overlap_len) - wav_chunk[:overlap_len] = wav_overlap * np.linspace(1.0, 0.0, overlap_len) - wav_chunk[:overlap_len] += crossfade_wav - - wav_overlap = wav_gen[-overlap_len:] - wav_gen_prev = wav_gen - return wav_chunk, wav_gen_prev, wav_overlap - - -cache = LRUCache(maxsize=128) - - -def synthesize_stream( - text: str, - temperature: float = 0.3, - top_P: float = 0.7, - top_K: float = 20, - spk: Union[int, Speaker] = -1, - infer_seed: int = -1, - use_decoder: bool = True, - prompt1: str = "", - prompt2: str = "", - prefix: str = "", - spliter_threshold: int = 100, - end_of_sentence: str = "", - overlap_wav_len: int = 1024, -) -> Generator[tuple[int, np.ndarray], None, None]: - cachekey = cache_keys.hashkey( - text, - temperature, - top_P, - top_K, - spk=spk if isinstance(spk, int) else spk.id, - infer_seed=infer_seed, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, - spliter_threshold=spliter_threshold, - end_of_sentence=end_of_sentence, - overlap_wav_len=overlap_wav_len, - ) - - if cachekey in cache: - for sr, wav in cache[cachekey]: - yield sr, wav - return - - spliter = SentenceSplitter(spliter_threshold) - sentences = spliter.parse(text) - - wav_gen_prev = None - wav_overlap = None - - total_wavs = [] - for sentence in sentences: - wav_gen = generate.generate_audio_stream( - text=sentence + end_of_sentence, - temperature=temperature, - top_P=top_P, - top_K=top_K, - spk=spk, - infer_seed=infer_seed, - use_decoder=use_decoder, - prompt1=prompt1, - prompt2=prompt2, - prefix=prefix, - ) - - for sr, wav in wav_gen: - total_wavs.append((sr, wav)) - yield sr, wav - - # NOTE: 作用很微妙对质量有一点改善 - # for sr, wav in wav_gen: - # total_wavs.append((sr, wav)) - # wav_gen = ( - # wav if wav_gen_prev is None else np.concatenate([wav_gen_prev, wav]) - # ) - # wav_chunk, wav_gen_prev, wav_overlap = handle_chunks( - # wav_gen, wav_gen_prev, wav_overlap, overlap_wav_len - # ) - # yield sr, wav_chunk - - cache[cachekey] = [(sr, np.concatenate([wav for sr, wav in total_wavs]))] diff --git a/modules/utils/audio.py b/modules/utils/audio_utils.py similarity index 96% rename from modules/utils/audio.py rename to modules/utils/audio_utils.py index 59dfd2a..fcee0e0 100644 --- a/modules/utils/audio.py +++ b/modules/utils/audio_utils.py @@ -120,6 +120,13 @@ def apply_normalize( return pydub_to_np(segment) +def silence_np( + duration_s: float, +) -> tuple[int, np.ndarray]: + silence = 
AudioSegment.silent(duration=duration_s * 1000) + return pydub_to_np(silence) + + if __name__ == "__main__": input_file = sys.argv[1] diff --git a/modules/webui/finetune/ft_ui_utils.py b/modules/webui/finetune/ft_ui_utils.py index 5773722..5a1a000 100644 --- a/modules/webui/finetune/ft_ui_utils.py +++ b/modules/webui/finetune/ft_ui_utils.py @@ -2,7 +2,7 @@ import subprocess from typing import IO, Union -from modules.speaker import Speaker, speaker_mgr +from modules.core.speaker import Speaker, speaker_mgr def get_datasets_dir(): diff --git a/modules/webui/finetune/speaker_ft_tab.py b/modules/webui/finetune/speaker_ft_tab.py index 0353fe6..38cabfa 100644 --- a/modules/webui/finetune/speaker_ft_tab.py +++ b/modules/webui/finetune/speaker_ft_tab.py @@ -1,8 +1,8 @@ import gradio as gr +from modules.core.models import zoo +from modules.core.speaker import speaker_mgr from modules.Enhancer.ResembleEnhance import unload_enhancer -from modules.models import unload_chat_tts -from modules.speaker import speaker_mgr from modules.webui import webui_config from modules.webui.webui_utils import get_speaker_names @@ -16,7 +16,7 @@ def __init__(self): self.status_str = "idle" def unload_main_thread_models(self): - unload_chat_tts() + zoo.ChatTTS.unload_chat_tts() unload_enhancer() def run( diff --git a/modules/webui/speaker/speaker_creator.py b/modules/webui/speaker/speaker_creator.py index 82810e2..9faf838 100644 --- a/modules/webui/speaker/speaker_creator.py +++ b/modules/webui/speaker/speaker_creator.py @@ -3,8 +3,7 @@ import gradio as gr import torch -from modules.models import load_chat_tts -from modules.speaker import Speaker +from modules.core.speaker import Speaker from modules.utils.hf import spaces from modules.utils.rng import np_rng from modules.utils.SeedContext import SeedContext diff --git a/modules/webui/speaker/speaker_editor.py b/modules/webui/speaker/speaker_editor.py index 51bbb48..d4f3abb 100644 --- a/modules/webui/speaker/speaker_editor.py +++ b/modules/webui/speaker/speaker_editor.py @@ -3,7 +3,7 @@ import gradio as gr import torch -from modules.speaker import Speaker +from modules.core.speaker import Speaker from modules.utils.hf import spaces from modules.webui import webui_config from modules.webui.webui_utils import tts_generate diff --git a/modules/webui/speaker/speaker_merger.py b/modules/webui/speaker/speaker_merger.py index 5bbc662..2bf4b69 100644 --- a/modules/webui/speaker/speaker_merger.py +++ b/modules/webui/speaker/speaker_merger.py @@ -1,37 +1,36 @@ -import io import tempfile import gradio as gr import torch -from modules.speaker import Speaker, speaker_mgr +from modules.core.speaker import Speaker, speaker_mgr from modules.utils.hf import spaces from modules.webui import webui_config, webui_utils -from modules.webui.webui_utils import get_speakers, tts_generate +from modules.webui.webui_utils import tts_generate -def spk_to_tensor(spk): +def spk_to_tensor(spk: str): spk = spk.split(" : ")[1].strip() if " : " in spk else spk if spk == "None" or spk == "": return None return speaker_mgr.get_speaker(spk).emb -def get_speaker_show_name(spk): +def get_speaker_show_name(spk: str): if spk.gender == "*" or spk.gender == "": return spk.name return f"{spk.gender} : {spk.name}" def merge_spk( - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, + spk_a: str, + spk_a_w: float, + spk_b: str, + spk_b_w: float, + spk_c: str, + spk_c_w: float, + spk_d: str, + spk_d_w: float, ): tensor_a = spk_to_tensor(spk_a) tensor_b = spk_to_tensor(spk_b) diff --git 
a/modules/webui/ssml/podcast_tab.py b/modules/webui/ssml/podcast_tab.py index bd051b7..a37c9bc 100644 --- a/modules/webui/ssml/podcast_tab.py +++ b/modules/webui/ssml/podcast_tab.py @@ -2,7 +2,6 @@ import pandas as pd import torch -from modules.normalization import text_normalize from modules.utils.hf import spaces from modules.webui import webui_config, webui_utils @@ -19,7 +18,7 @@ def merge_dataframe_to_ssml(msg, spk, style, df: pd.DataFrame): spk = row.get("speaker") style = row.get("style") - text = text_normalize(text) + text = webui_utils.text_normalize(text) if text.strip() == "": continue diff --git a/modules/webui/ssml/spliter_tab.py b/modules/webui/ssml/spliter_tab.py index e7dc228..1d85653 100644 --- a/modules/webui/ssml/spliter_tab.py +++ b/modules/webui/ssml/spliter_tab.py @@ -1,10 +1,14 @@ import gradio as gr import torch -from modules.normalization import text_normalize from modules.utils.hf import spaces from modules.webui import webui_utils -from modules.webui.webui_utils import get_speakers, get_styles, split_long_text +from modules.webui.webui_utils import ( + get_speakers, + get_styles, + split_long_text, + text_normalize, +) # NOTE: 因为 text_normalize 需要使用 tokenizer diff --git a/modules/webui/webui_utils.py b/modules/webui/webui_utils.py index c636cac..34cbf6c 100644 --- a/modules/webui/webui_utils.py +++ b/modules/webui/webui_utils.py @@ -6,22 +6,24 @@ import torch.profiler from modules import refiner -from modules.api.impl.handler.SSMLHandler import SSMLHandler -from modules.api.impl.handler.TTSHandler import TTSHandler -from modules.api.impl.model.audio_model import AdjustConfig -from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig -from modules.api.impl.model.enhancer_model import EnhancerConfig from modules.api.utils import calc_spk_style +from modules.core.handler.datacls.audio_model import AdjustConfig +from modules.core.handler.datacls.chattts_model import ChatTTSConfig, InferConfig +from modules.core.handler.datacls.enhancer_model import EnhancerConfig +from modules.core.handler.SSMLHandler import SSMLHandler +from modules.core.handler.TTSHandler import TTSHandler +from modules.core.speaker import Speaker, speaker_mgr +from modules.core.ssml.SSMLParser import SSMLBreak, SSMLSegment +from modules.core.tn import ChatTtsTN +from modules.core.tools.SentenceSplitter import SentenceSplitter from modules.data import styles_mgr from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance -from modules.normalization import text_normalize -from modules.SentenceSplitter import SentenceSplitter -from modules.speaker import Speaker, speaker_mgr -from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLSegment, create_ssml_parser -from modules.utils import audio +from modules.utils import audio_utils from modules.utils.hf import spaces from modules.webui import webui_config +from modules.core.ssml.SSMLParser import create_ssml_v01_parser + def get_speakers(): return speaker_mgr.list_speakers() @@ -109,7 +111,7 @@ def synthesize_ssml( if ssml == "": raise gr.Error("SSML is empty, please input some SSML") - parser = create_ssml_parser() + parser = create_ssml_v01_parser() segments = parser.parse(ssml) max_len = webui_config.ssml_max segments = segments_length_limit(segments, max_len) @@ -123,6 +125,8 @@ def synthesize_ssml( eos=eos, # NOTE: SSML not support `infer_seed` contorl # seed=42, + # NOTE: 开启以支持 track_tqdm + sync_gen=True, ) adjust_config = AdjustConfig( pitch=pitch, @@ -143,12 +147,12 @@ def synthesize_ssml( 
enhancer_config=enhancer_config, ) - audio_data, sr = handler.enqueue() + sample_rate, audio_data = handler.enqueue() # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式 - audio_data = audio.audio_to_int16(audio_data) + audio_data = audio_utils.audio_to_int16(audio_data) - return sr, audio_data + return sample_rate, audio_data # @torch.inference_mode() @@ -234,6 +238,8 @@ def tts_generate( spliter_threshold=spliter_thr, eos=eos, seed=infer_seed, + # NOTE: 开启以支持 track_tqdm + sync_gen=True, ) adjust_config = AdjustConfig( pitch=pitch, @@ -256,13 +262,17 @@ def tts_generate( enhancer_config=enhancer_config, ) - audio_data, sample_rate = handler.enqueue() + sample_rate, audio_data = handler.enqueue() # NOTE: 这里必须要加,不然 gradio 没法解析成 mp3 格式 - audio_data = audio.audio_to_int16(audio_data) + audio_data = audio_utils.audio_to_int16(audio_data) return sample_rate, audio_data +def text_normalize(text: str) -> str: + return ChatTtsTN.ChatTtsTN.normalize(text) + + @torch.inference_mode() @spaces.GPU(duration=120) def refine_text( @@ -298,6 +308,6 @@ def split_long_text(long_text_input, spliter_threshold=100, eos=""): sentences = [text_normalize(s) + eos for s in sentences] data = [] for i, text in enumerate(sentences): - token_length = spliter.count_tokens(text) + token_length = spliter.len(text) data.append([i, text, token_length]) return data diff --git a/playground/client.mjs b/playground/client.mjs index 245277d..1502224 100644 --- a/playground/client.mjs +++ b/playground/client.mjs @@ -43,6 +43,7 @@ class APIClient { prefix = "", bs = 8, thr = 100, + no_cache = false, stream = false, }) { const params = new URLSearchParams({ @@ -59,6 +60,7 @@ class APIClient { prefix, bs, thr, + no_cache, stream, }); // return `${this.client.defaults.baseURL}v1/tts?${params.toString()}`; diff --git a/playground/index.html b/playground/index.html index d2be0fc..2f02a78 100644 --- a/playground/index.html +++ b/playground/index.html @@ -5,7 +5,7 @@ + href="data:image/svg+xml,🍦"> ChatTTS Forge Playground
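After the handler changes above, webui callers receive (sample_rate, audio_data) from handler.enqueue() — the tuple order is swapped relative to the old API — and convert the samples to int16 via the renamed audio_utils module before handing them to Gradio. A minimal sketch of that consumption pattern, assuming only the imports and calls visible in this diff; the handler construction itself is omitted:

import numpy as np

from modules.core.handler.TTSHandler import TTSHandler
from modules.utils import audio_utils


def render_for_gradio(handler: TTSHandler) -> tuple[int, np.ndarray]:
    # enqueue() now returns (sample_rate, audio_data), not (audio_data, sr)
    sample_rate, audio_data = handler.enqueue()
    # Gradio needs int16 samples to encode the mp3 output correctly
    audio_data = audio_utils.audio_to_int16(audio_data)
    return sample_rate, audio_data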