diff --git a/modules/ChatTTS/ChatTTS/core.py b/modules/ChatTTS/ChatTTS/core.py
index 91e8a61..17ef9a8 100644
--- a/modules/ChatTTS/ChatTTS/core.py
+++ b/modules/ChatTTS/ChatTTS/core.py
@@ -647,6 +647,7 @@ def _infer_code(
             stream=stream,
             context=self.context,
             stream_chunk_size=params.stream_chunk_size,
+            ensure_non_empty=params.ensure_non_empty,
         )
 
         del emb, input_ids
diff --git a/modules/ChatTTS/ChatTTS/model/gpt.py b/modules/ChatTTS/ChatTTS/model/gpt.py
index 47fb52f..611bf9c 100644
--- a/modules/ChatTTS/ChatTTS/model/gpt.py
+++ b/modules/ChatTTS/ChatTTS/model/gpt.py
@@ -506,6 +506,10 @@ def generate(
 
             del logits
 
+            if i == 0:
+                # when i == 0, we want to ensure that the first token is not eos_token
+                scores[:, eos_token] = 0
+
             idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
 
             if not infer_text:
@@ -541,27 +545,39 @@ def generate(
                 if show_tqdm:
                     pbar.close()
                 self.logger.warn("regenerate in order to ensure non-empty")
+                del_all(attentions)
+                del_all(hiddens)
+                del (
+                    start_idx,
+                    end_idx,
+                    finish,
+                    temperature,
+                    attention_mask_cache,
+                    past_key_values,
+                    idx_next,
+                    inputs_ids_tmp,
+                )
                 new_gen = self.generate(
-                    emb,
-                    inputs_ids,
-                    old_temperature,
-                    eos_token,
-                    attention_mask,
-                    max_new_token,
-                    min_new_token,
-                    logits_warpers,
-                    logits_processors,
-                    infer_text,
-                    return_attn,
-                    return_hidden,
-                    stream,
-                    show_tqdm,
-                    ensure_non_empty,
-                    context,
+                    emb=emb,
+                    inputs_ids=inputs_ids,
+                    old_temperature=old_temperature,
+                    eos_token=eos_token,
+                    attention_mask=attention_mask,
+                    max_new_token=max_new_token,
+                    min_new_token=min_new_token,
+                    logits_warpers=logits_warpers,
+                    logits_processors=logits_processors,
+                    infer_text=infer_text,
+                    return_attn=return_attn,
+                    return_hidden=return_hidden,
+                    stream=stream,
+                    show_tqdm=show_tqdm,
+                    ensure_non_empty=ensure_non_empty,
+                    context=context,
                 )
                 for result in new_gen:
                     yield result
-                    return
+                return
 
             del inputs_ids
             inputs_ids = inputs_ids_tmp
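
# --- note -----------------------------------------------------------------
# The two gpt.py hunks above implement "ensure_non_empty": the EOS token is
# masked out of the very first sampling step, and if generation still comes
# back empty, intermediate tensors are freed and generate() restarts itself
# with explicit keyword arguments. A minimal sketch of the first-token
# masking, using an invented 4-token vocabulary (index 3 as EOS) rather than
# the real ChatTTS one:

import torch

def sample_token(scores: torch.Tensor, eos_token: int, step: int) -> torch.Tensor:
    """Draw one token per batch row; on step 0 the EOS column is zeroed first."""
    if step == 0:
        scores = scores.clone()
        scores[:, eos_token] = 0  # the first token can never be EOS
    # multinomial accepts unnormalized non-negative weights, so zeroing a
    # column simply removes it from the draw
    return torch.multinomial(scores, num_samples=1)

probs = torch.tensor([[0.1, 0.1, 0.1, 0.7]])
print(sample_token(probs, eos_token=3, step=0))  # always token 0, 1, or 2
# ---------------------------------------------------------------------------
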
diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py
index 652801d..a98fa17 100644
--- a/modules/api/impl/openai_api.py
+++ b/modules/api/impl/openai_api.py
@@ -92,6 +92,7 @@ async def openai_speech_api(
         spliter_threshold=spliter_threshold,
         eos=eos,
         seed=seed,
+        stream=stream,
     )
     adjust_config = AdjustConfig(speaking_rate=speed)
     enhancer_config = EnhancerConfig(
diff --git a/modules/api/impl/ssml_api.py b/modules/api/impl/ssml_api.py
index 50a0f27..f5c6711 100644
--- a/modules/api/impl/ssml_api.py
+++ b/modules/api/impl/ssml_api.py
@@ -24,6 +24,8 @@ class SSMLRequest(BaseModel):
     enhancer: EnhancerConfig = EnhancerConfig()
     adjuster: AdjustConfig = AdjustConfig()
 
+    stream: bool = False
+
 
 async def synthesize_ssml_api(
     request: SSMLRequest = Body(
@@ -35,6 +37,7 @@ async def synthesize_ssml_api(
         format = request.format.lower()
         batch_size = request.batch_size
         eos = request.eos
+        stream = request.stream
         spliter_thr = request.spliter_thr
         enhancer = request.enhancer
         adjuster = request.adjuster
@@ -58,9 +61,7 @@ async def synthesize_ssml_api(
         )
 
         infer_config = InferConfig(
-            batch_size=batch_size,
-            spliter_threshold=spliter_thr,
-            eos=eos,
+            batch_size=batch_size, spliter_threshold=spliter_thr, eos=eos, stream=stream
         )
         adjust_config = adjuster
         enhancer_config = enhancer
@@ -72,12 +73,18 @@ async def synthesize_ssml_api(
             enhancer_config=enhancer_config,
         )
 
-        buffer = handler.enqueue_to_buffer(format=request.format)
-
-        mime_type = f"audio/{format}"
+        media_type = f"audio/{format}"
         if format == AudioFormat.mp3:
-            mime_type = "audio/mpeg"
-        return StreamingResponse(buffer, media_type=mime_type)
+            media_type = "audio/mpeg"
+
+        if stream:
+            gen = handler.enqueue_to_stream_with_request(
+                request=request, format=AudioFormat(format)
+            )
+            return StreamingResponse(gen, media_type=media_type)
+        else:
+            buffer = handler.enqueue_to_buffer(format=AudioFormat(format))
+            return StreamingResponse(buffer, media_type=media_type)
 
     except Exception as e:
         import logging
diff --git a/modules/api/impl/tts_api.py b/modules/api/impl/tts_api.py
index bb81838..07d911a 100644
--- a/modules/api/impl/tts_api.py
+++ b/modules/api/impl/tts_api.py
@@ -102,6 +102,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
         prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
         prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
         eos = params.eos or ""
+        stream = params.stream
 
         batch_size = int(params.bs)
         threshold = int(params.thr)
@@ -120,6 +121,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
             spliter_threshold=threshold,
             eos=eos,
             seed=seed,
+            stream=stream,
         )
         adjust_config = AdjustConfig(
             pitch=params.pitch,
@@ -143,13 +145,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
         if params.format == "mp3":
             media_type = "audio/mpeg"
 
-        if params.stream:
-            if infer_config.batch_size != 1:
-                # streaming only supports batch size 1; this request parameter will be ignored
-                logger.warning(
-                    f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
-                )
-
+        if stream:
             gen = handler.enqueue_to_stream_with_request(
                 request=request, format=AudioFormat(params.format)
             )
diff --git a/modules/api/impl/xtts_v2_api.py b/modules/api/impl/xtts_v2_api.py
index 66a806e..191c952 100644
--- a/modules/api/impl/xtts_v2_api.py
+++ b/modules/api/impl/xtts_v2_api.py
@@ -111,6 +111,7 @@ async def tts_to_audio(request: SynthesisRequest):
         spliter_threshold=XTTSV2.spliter_threshold,
         eos=XTTSV2.eos,
         seed=XTTSV2.infer_seed,
+        stream=False,
     )
     adjust_config = AdjustConfig(
         speed_rate=XTTSV2.speed,
@@ -164,6 +165,7 @@ async def tts_stream(
         spliter_threshold=XTTSV2.spliter_threshold,
         eos=XTTSV2.eos,
         seed=XTTSV2.infer_seed,
+        stream=True,
     )
     adjust_config = AdjustConfig(
         speed_rate=XTTSV2.speed,
diff --git a/modules/core/handler/AudioHandler.py b/modules/core/handler/AudioHandler.py
index d71d4af..c31b9fb 100644
--- a/modules/core/handler/AudioHandler.py
+++ b/modules/core/handler/AudioHandler.py
@@ -97,8 +97,7 @@ def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]
     async def enqueue_to_stream_with_request(
         self, request: Request, format: AudioFormat
     ) -> AsyncGenerator[bytes, None]:
-        buffer_gen = self.enqueue_to_stream(format=AudioFormat(format))
-        for chunk in buffer_gen:
+        for chunk in self.enqueue_to_stream(format=AudioFormat(format)):
             disconnected = await request.is_disconnected()
             if disconnected:
                 ChatTTSInfer.interrupt()
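
# --- note -----------------------------------------------------------------
# The API-layer hunks above thread one new InferConfig flag ("stream") from
# each endpoint down to the handler, and AudioHandler now cancels inference
# as soon as the HTTP client goes away. A hedged sketch of that
# disconnect-aware streaming loop, with a hypothetical stop() callback
# standing in for ChatTTSInfer.interrupt():

from typing import AsyncGenerator, Callable, Iterable

from fastapi import Request

async def stream_until_disconnect(
    request: Request,
    chunks: Iterable[bytes],
    stop: Callable[[], None],
) -> AsyncGenerator[bytes, None]:
    for chunk in chunks:
        # is_disconnected() is awaited between chunks, so a closed
        # connection stops decoding instead of wasting GPU time
        if await request.is_disconnected():
            stop()
            return
        yield chunk
# ---------------------------------------------------------------------------
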
diff --git a/modules/core/models/tts/ChatTtsModel.py b/modules/core/models/tts/ChatTtsModel.py
index 1736edf..f0a864c 100644
--- a/modules/core/models/tts/ChatTtsModel.py
+++ b/modules/core/models/tts/ChatTtsModel.py
@@ -137,7 +137,11 @@ def _gen():
                 prompt2=prompt2,
                 prefix=prefix,
             )
-            audio_arr: list[NP_AUDIO] = [(sr, data[0]) for data in results]
+            audio_arr: list[NP_AUDIO] = [
+                # NOTE: data[0] means stereo => mono audio
+                (sr, np.empty(0)) if data is None else (sr, data[0])
+                for data in results
+            ]
 
             self.set_cache(segments=segments, context=context, value=audio_arr)
             return audio_arr
@@ -157,7 +161,17 @@ def _gen() -> Generator[list[NP_AUDIO], None, None]:
                 prefix=prefix,
                 stream_chunk_size=chunk_size,
             ):
-                audio_arr: list[NP_AUDIO] = [(sr, data[0]) for data in results]
+                results = [
+                    (
+                        np.empty(0)
+                        # None means generation failed; size == 0 means generation finished
+                        if data is None or data.size == 0
+                        # NOTE: data[0] means stereo => mono audio
+                        else data[0]
+                    )
+                    for data in results
+                ]
+                audio_arr: list[NP_AUDIO] = [(sr, data) for data in results]
                 yield audio_arr
 
                 if audio_arr_buff is None:
@@ -165,8 +179,8 @@
                 else:
                     for i, data in enumerate(results):
                         sr1, before = audio_arr_buff[i]
-                        buff = np.concatenate([before, data[0]], axis=0)
+                        buff = np.concatenate([before, data], axis=0)
                         audio_arr_buff[i] = (sr1, buff)
 
             self.set_cache(segments=segments, context=context, value=audio_arr_buff)
 
-        return _gen(infer)
+        return _gen()
diff --git a/modules/core/models/zoo/ChatTTSInfer.py b/modules/core/models/zoo/ChatTTSInfer.py
index 6f18980..d65b8f8 100644
--- a/modules/core/models/zoo/ChatTTSInfer.py
+++ b/modules/core/models/zoo/ChatTTSInfer.py
@@ -336,7 +336,7 @@ def generate_audio_stream(
         stream_chunk_size=96,
         use_decoder=True,
     ) -> Generator[list[np.ndarray], None, None]:
-        gen = self._generate_audio(
+        gen: Generator[list[np.ndarray], None, None] = self._generate_audio(
             text=text,
             spk_emb=spk_emb,
             top_P=top_P,
@@ -354,9 +354,11 @@ def generate_audio_stream(
 
         def _generator():
            with disable_tqdm(enabled=config.runtime_env_vars.off_tqdm):
-                for audio in gen:
-                    if audio is not None:
-                        yield audio
+                for audio_arr in gen:
+                    # fill empty slots with an empty array
+                    # NOTE: lengths vary, so any given position may be None
+                    audio_arr = [np.empty(0) if i is None else i for i in audio_arr]
+                    yield audio_arr
 
         return _generator()
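
# --- note -----------------------------------------------------------------
# ChatTtsModel.py and ChatTTSInfer.py above converge on one convention: a
# failed (None) or finished (size == 0) slot in a batch is replaced by an
# empty ndarray, so downstream buffers can concatenate unconditionally. A
# small self-contained sketch of that convention (function name and shapes
# are illustrative, not from the repo):

from typing import Optional

import numpy as np

def to_mono_or_empty(data: Optional[np.ndarray]) -> np.ndarray:
    # None => generation failed; size == 0 => generation finished
    if data is None or data.size == 0:
        return np.empty(0)
    return data[0]  # channel 0: stereo => mono

buff = np.empty(0)
for chunk in [None, np.ones((2, 4)), np.empty((2, 0))]:
    buff = np.concatenate([buff, to_mono_or_empty(chunk)], axis=0)
print(buff.shape)  # (4,) -- the None and finished chunks added nothing
# ---------------------------------------------------------------------------
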
diff --git a/modules/core/pipeline/factory.py b/modules/core/pipeline/factory.py
index c674c89..547ecc1 100644
--- a/modules/core/pipeline/factory.py
+++ b/modules/core/pipeline/factory.py
@@ -27,6 +27,7 @@ def _process_array(
         lambd = enhancer_config.lambd
         tau = enhancer_config.tau
 
+        sample_rate, audio_data = audio
         audio_data, sample_rate = apply_audio_enhance_full(
             audio_data=audio_data,
             sr=sample_rate,
@@ -36,12 +37,12 @@ def _process_array(
             tau=tau,
         )
 
-        return audio_data, sample_rate
+        return sample_rate, audio_data
 
 
 class AdjusterProcessor(AudioProcessor):
     def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO:
-        sr, audio_data = audio
+        sample_rate, audio_data = audio
         adjust_config = context.adjust_config
 
         audio_data = audio_utils.apply_prosody_to_audio_data(
@@ -49,9 +50,9 @@ def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUD
             rate=adjust_config.speed_rate,
             pitch=adjust_config.pitch,
             volume=adjust_config.volume_gain_db,
-            sr=sr,
+            sr=sample_rate,
         )
-        return sr, audio_data
+        return sample_rate, audio_data
 
 
 class AudioNormalizer(AudioProcessor):
@@ -68,7 +69,7 @@ def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUD
 
 
 class ChatTtsTNProcessor(TextProcessor):
     def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegment:
-        segment.text = ChatTtsTN.normalize(segment.text, context.tts_config)
+        segment.text = ChatTtsTN.normalize(text=segment.text, config=context.tn_config)
         return segment
diff --git a/modules/core/pipeline/generate/BatchGenerate.py b/modules/core/pipeline/generate/BatchGenerate.py
index 223592e..4a92471 100644
--- a/modules/core/pipeline/generate/BatchGenerate.py
+++ b/modules/core/pipeline/generate/BatchGenerate.py
@@ -69,7 +69,7 @@ def generate_batch_stream(self, batch: TTSBatch):
         ):
             for audio, result in zip(batch.segments, results):
                 sr, data = result
-                if len(data) == 0:
+                if data.size == 0:
                     audio.done = True
                     continue
                 audio.data = np.concatenate([audio.data, data], axis=0)
diff --git a/modules/core/pipeline/pipeline.py b/modules/core/pipeline/pipeline.py
index fa925d3..fc19366 100644
--- a/modules/core/pipeline/pipeline.py
+++ b/modules/core/pipeline/pipeline.py
@@ -32,6 +32,9 @@ def set_model(self, model):
     def create_synth(self):
         chunker = TTSChunker(context=self.context)
         segments = chunker.segments()
+        # it would actually be better to call this before the chunker... but it has side effects, so it runs afterwards
+        segments = [self.process_text(seg) for seg in segments]
+
         synth = BatchSynth(
             input_segments=segments, context=self.context, model=self.model
         )
@@ -41,7 +44,8 @@ def generate(self) -> NP_AUDIO:
         synth = self.create_synth()
         synth.start_generate()
         synth.wait_done()
-        return synth.sr(), synth.read()
+        audio = synth.sr(), synth.read()
+        return self.process_np_audio(audio)
 
     def generate_stream(self) -> Generator[NP_AUDIO, None, None]:
         synth = self.create_synth()
@@ -49,12 +53,18 @@ def generate_stream(self) -> Generator[NP_AUDIO, None, None]:
         while not synth.is_done():
             data = synth.read()
             if data.size > 0:
-                yield synth.sr(), data
+                audio = synth.sr(), data
+                yield self.process_np_audio(audio)
             # TODO: replace with threading.Event
             sleep(0.1)
 
         data = synth.read()
         if data.size > 0:
-            yield synth.sr(), data
+            audio = synth.sr(), data
+            yield self.process_np_audio(audio)
+
+    def process_np_audio(self, audio: NP_AUDIO) -> NP_AUDIO:
+        audio = self.process_audio(audio)
+        return self.ensure_audio_type(audio, "ndarray")
 
     def ensure_audio_type(
         self, audio: AUDIO, output_type: Literal["ndarray", "segment"]
@@ -80,14 +90,14 @@ def _to_ndarray(self, audio: AUDIO) -> NP_AUDIO:
             return sr, audio
         return audio
 
-    def process_text(self, text: TTSSegment):
+    def process_text(self, seg: TTSSegment):
         for module in self.modules:
             if isinstance(module, TextProcessor):
-                text = module.process(text)
-        return text
+                seg = module.process(segment=seg, context=self.context)
+        return seg
 
     def process_audio(self, audio: AUDIO):
         for module in self.modules:
             if isinstance(module, AudioProcessor):
-                audio = module.process(audio)
+                audio = module.process(audio=audio, context=self.context)
         return audio
diff --git a/modules/core/pipeline/processor.py b/modules/core/pipeline/processor.py
index 4beda7d..3eb4cf8 100644
--- a/modules/core/pipeline/processor.py
+++ b/modules/core/pipeline/processor.py
@@ -17,7 +17,7 @@ def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegmen
 
 class AudioProcessor:
     def process(self, audio: AUDIO, context: TTSPipelineContext) -> AUDIO:
-        if isinstance(audio, np.ndarray):
+        if isinstance(audio, tuple):
             return self._process_array(audio, context)
         elif isinstance(audio, AudioSegment):
             return self._process_segment(audio, context)
@@ -25,7 +25,8 @@ def process(self, audio: AUDIO, context: TTSPipelineContext) -> AUDIO:
             raise ValueError("Unsupported audio type")
 
     def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO:
-        segment = audio_utils.ndarray_to_segment(audio)
+        sr, data = audio
+        segment = audio_utils.ndarray_to_segment(ndarray=data, frame_rate=sr)
         processed_segment = self._process_segment(segment, context)
         return audio_utils.audiosegment_to_librosawav(processed_segment)
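
# --- note -----------------------------------------------------------------
# factory.py and processor.py both hinge on NP_AUDIO being a
# (sample_rate, ndarray) tuple: the enhancer hunk fixes a return value that
# came back swapped, and the isinstance dispatch now tests for tuple, since
# an NP_AUDIO value is never a bare np.ndarray. A sketch of the corrected
# dispatch (type aliases reconstructed here from the hunks, not imported):

from typing import Tuple, Union

import numpy as np
from pydub import AudioSegment

NP_AUDIO = Tuple[int, np.ndarray]  # (sample_rate, samples)
AUDIO = Union[NP_AUDIO, AudioSegment]

def process(audio: AUDIO) -> AUDIO:
    if isinstance(audio, tuple):  # NP_AUDIO: unpack in canonical order
        sample_rate, audio_data = audio
        # ... transform audio_data here ...
        return sample_rate, audio_data  # and return in the same order
    elif isinstance(audio, AudioSegment):
        return audio
    raise ValueError("Unsupported audio type")
# ---------------------------------------------------------------------------
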
diff --git a/modules/core/tn/TNPipeline.py b/modules/core/tn/TNPipeline.py
index 8a9a751..8ce58b7 100644
--- a/modules/core/tn/TNPipeline.py
+++ b/modules/core/tn/TNPipeline.py
@@ -114,14 +114,20 @@ def guess_langs(self, text: str):
         guess = GuessLang(zh_or_en=zh_or_en, detected=detected)
         return guess
 
-    def _normalize(self, text: str, config: Optional[TNConfig] = None):
+    def _normalize(self, text: str, config: Optional[TNConfig] = TNConfig()):
+        if config is None:
+            config = TNConfig()
+        enabled_block = config.enabled if config.enabled else []
+        disabled_block = config.disabled if config.disabled else []
+
         guess = self.guess_langs(text)
         for block in self.blocks:
             enabled = block.enabled
-            if config is not None and block.name in config.enabled:
+
+            if block.name in enabled_block:
                 enabled = True
-            if config is not None and block.name in config.disabled:
+            if block.name in disabled_block:
                 enabled = False
 
             if not enabled:
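
# --- note -----------------------------------------------------------------
# The _normalize change reduces each block decision to two list lookups over
# an always-present config. One caveat worth knowing: a default like
# "config: Optional[TNConfig] = TNConfig()" is evaluated once at definition
# time and shared across calls, which is only safe while it is never mutated
# (the explicit None check added above covers callers that pass None). A
# condensed sketch of the decision logic, with TNConfig fields assumed to be
# optional lists per the hunk:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TNConfig:
    enabled: Optional[List[str]] = None
    disabled: Optional[List[str]] = None

def block_enabled(name: str, default: bool, config: Optional[TNConfig]) -> bool:
    config = config or TNConfig()
    enabled = default
    if name in (config.enabled or []):
        enabled = True
    if name in (config.disabled or []):
        enabled = False  # an explicit disable wins, matching the hunk's order
    return enabled

assert block_enabled("tn_block", True, None) is True
assert block_enabled("tn_block", True, TNConfig(disabled=["tn_block"])) is False
# ---------------------------------------------------------------------------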