Skip to content

Commit

Permalink
✨ 完善 stream mode
Browse files (browse the repository at this point in the history)
- add ssml api stream mode
- fix stream bugs
- fix pipeline bugs
- fix ensure_non_empty issues
  • Loading branch information
zhzLuke96 committed Jul 12, 2024
1 parent fd93e79 commit 3095d4c
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 60 deletions.
1 change: 1 addition & 0 deletions modules/ChatTTS/ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ def _infer_code(
stream=stream,
context=self.context,
stream_chunk_size=params.stream_chunk_size,
ensure_non_empty=params.ensure_non_empty,
)

del emb, input_ids
Expand Down
50 changes: 33 additions & 17 deletions modules/ChatTTS/ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,10 @@ def generate(

del logits

if i == 0:
# when i == 0, we want to ensure that the first token is not eos_token
scores[:, eos_token] = 0

idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

if not infer_text:
Expand Down Expand Up @@ -541,27 +545,39 @@ def generate(
if show_tqdm:
pbar.close()
self.logger.warn("regenerate in order to ensure non-empty")
del_all(attentions)
del_all(hiddens)
del (
start_idx,
end_idx,
finish,
temperature,
attention_mask_cache,
past_key_values,
idx_next,
inputs_ids_tmp,
)
new_gen = self.generate(
emb,
inputs_ids,
old_temperature,
eos_token,
attention_mask,
max_new_token,
min_new_token,
logits_warpers,
logits_processors,
infer_text,
return_attn,
return_hidden,
stream,
show_tqdm,
ensure_non_empty,
context,
emb=emb,
inputs_ids=inputs_ids,
old_temperature=old_temperature,
eos_token=eos_token,
attention_mask=attention_mask,
max_new_token=max_new_token,
min_new_token=min_new_token,
logits_warpers=logits_warpers,
logits_processors=logits_processors,
infer_text=infer_text,
return_attn=return_attn,
return_hidden=return_hidden,
stream=stream,
show_tqdm=show_tqdm,
ensure_non_empty=ensure_non_empty,
context=context,
)
for result in new_gen:
yield result
return
return

del inputs_ids
inputs_ids = inputs_ids_tmp
Expand Down
1 change: 1 addition & 0 deletions modules/api/impl/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ async def openai_speech_api(
spliter_threshold=spliter_threshold,
eos=eos,
seed=seed,
stream=stream,
)
adjust_config = AdjustConfig(speaking_rate=speed)
enhancer_config = EnhancerConfig(
Expand Down
23 changes: 15 additions & 8 deletions modules/api/impl/ssml_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SSMLRequest(BaseModel):
enhancer: EnhancerConfig = EnhancerConfig()
adjuster: AdjustConfig = AdjustConfig()

stream: bool = False


async def synthesize_ssml_api(
request: SSMLRequest = Body(
Expand All @@ -35,6 +37,7 @@ async def synthesize_ssml_api(
format = request.format.lower()
batch_size = request.batch_size
eos = request.eos
stream = request.stream
spliter_thr = request.spliter_thr
enhancer = request.enhancer
adjuster = request.adjuster
Expand All @@ -58,9 +61,7 @@ async def synthesize_ssml_api(
)

infer_config = InferConfig(
batch_size=batch_size,
spliter_threshold=spliter_thr,
eos=eos,
batch_size=batch_size, spliter_threshold=spliter_thr, eos=eos, stream=stream
)
adjust_config = adjuster
enhancer_config = enhancer
Expand All @@ -72,12 +73,18 @@ async def synthesize_ssml_api(
enhancer_config=enhancer_config,
)

buffer = handler.enqueue_to_buffer(format=request.format)

mime_type = f"audio/{format}"
media_type = f"audio/{format}"
if format == AudioFormat.mp3:
mime_type = "audio/mpeg"
return StreamingResponse(buffer, media_type=mime_type)
media_type = "audio/mpeg"

if stream:
gen = handler.enqueue_to_stream_with_request(
request=request, format=AudioFormat(format)
)
return StreamingResponse(gen, media_type=media_type)
else:
buffer = handler.enqueue_to_buffer(format=AudioFormat(format))
return StreamingResponse(buffer, media_type=media_type)

except Exception as e:
import logging
Expand Down
10 changes: 3 additions & 7 deletions modules/api/impl/tts_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
eos = params.eos or ""
stream = params.stream

batch_size = int(params.bs)
threshold = int(params.thr)
Expand All @@ -120,6 +121,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
spliter_threshold=threshold,
eos=eos,
seed=seed,
stream=stream,
)
adjust_config = AdjustConfig(
pitch=params.pitch,
Expand All @@ -143,13 +145,7 @@ async def synthesize_tts(request: Request, params: TTSParams = Depends()):
if params.format == "mp3":
media_type = "audio/mpeg"

if params.stream:
if infer_config.batch_size != 1:
# Streaming generation only supports a batch size of 1; the batch size in the request will be ignored
logger.warning(
f"Batch size {infer_config.batch_size} is not supported in streaming mode, will set to 1"
)

if stream:
gen = handler.enqueue_to_stream_with_request(
request=request, format=AudioFormat(params.format)
)
Expand Down
2 changes: 2 additions & 0 deletions modules/api/impl/xtts_v2_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ async def tts_to_audio(request: SynthesisRequest):
spliter_threshold=XTTSV2.spliter_threshold,
eos=XTTSV2.eos,
seed=XTTSV2.infer_seed,
stream=False,
)
adjust_config = AdjustConfig(
speed_rate=XTTSV2.speed,
Expand Down Expand Up @@ -164,6 +165,7 @@ async def tts_stream(
spliter_threshold=XTTSV2.spliter_threshold,
eos=XTTSV2.eos,
seed=XTTSV2.infer_seed,
stream=True,
)
adjust_config = AdjustConfig(
speed_rate=XTTSV2.speed,
Expand Down
3 changes: 1 addition & 2 deletions modules/core/handler/AudioHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def enqueue_to_stream(self, format: AudioFormat) -> Generator[bytes, None, None]
async def enqueue_to_stream_with_request(
self, request: Request, format: AudioFormat
) -> AsyncGenerator[bytes, None]:
buffer_gen = self.enqueue_to_stream(format=AudioFormat(format))
for chunk in buffer_gen:
for chunk in self.enqueue_to_stream(format=AudioFormat(format)):
disconnected = await request.is_disconnected()
if disconnected:
ChatTTSInfer.interrupt()
Expand Down
22 changes: 18 additions & 4 deletions modules/core/models/tts/ChatTtsModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def _gen():
prompt2=prompt2,
prefix=prefix,
)
audio_arr: list[NP_AUDIO] = [(sr, data[0]) for data in results]
audio_arr: list[NP_AUDIO] = [
# NOTE: data[0] means stereo => mono audio
(sr, np.empty(0)) if data is None else (sr, data[0])
for data in results
]

self.set_cache(segments=segments, context=context, value=audio_arr)
return audio_arr
Expand All @@ -157,16 +161,26 @@ def _gen() -> Generator[list[NP_AUDIO], None, None]:
prefix=prefix,
stream_chunk_size=chunk_size,
):
audio_arr: list[NP_AUDIO] = [(sr, data[0]) for data in results]
results = [
(
np.empty(0)
# None presumably means generation failed; size == 0 means generation has finished
if data is None or data.size == 0
# NOTE: data[0] means stereo => mono audio
else data[0]
)
for data in results
]
audio_arr: list[NP_AUDIO] = [(sr, data) for data in results]
yield audio_arr

if audio_arr_buff is None:
audio_arr_buff = audio_arr
else:
for i, data in enumerate(results):
sr1, before = audio_arr_buff[i]
buff = np.concatenate([before, data[0]], axis=0)
buff = np.concatenate([before, data], axis=0)
audio_arr_buff[i] = (sr1, buff)
self.set_cache(segments=segments, context=context, value=audio_arr_buff)

return _gen(infer)
return _gen()
10 changes: 6 additions & 4 deletions modules/core/models/zoo/ChatTTSInfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def generate_audio_stream(
stream_chunk_size=96,
use_decoder=True,
) -> Generator[list[np.ndarray], None, None]:
gen = self._generate_audio(
gen: Generator[list[np.ndarray], None, None] = self._generate_audio(
text=text,
spk_emb=spk_emb,
top_P=top_P,
Expand All @@ -354,9 +354,11 @@ def generate_audio_stream(

def _generator():
with disable_tqdm(enabled=config.runtime_env_vars.off_tqdm):
for audio in gen:
if audio is not None:
yield audio
for audio_arr in gen:
# If an entry is None, fill it with an empty array
# NOTE: lengths are not guaranteed to match, so some positions may be None
audio_arr = [np.empty(0) if i is None else i for i in audio_arr]
yield audio_arr

return _generator()

Expand Down
11 changes: 6 additions & 5 deletions modules/core/pipeline/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _process_array(
lambd = enhancer_config.lambd
tau = enhancer_config.tau

sample_rate, audio_data = audio
audio_data, sample_rate = apply_audio_enhance_full(
audio_data=audio_data,
sr=sample_rate,
Expand All @@ -36,22 +37,22 @@ def _process_array(
tau=tau,
)

return audio_data, sample_rate
return sample_rate, audio_data


class AdjusterProcessor(AudioProcessor):
def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO:
sr, audio_data = audio
sample_rate, audio_data = audio
adjust_config = context.adjust_config

audio_data = audio_utils.apply_prosody_to_audio_data(
audio_data=audio_data,
rate=adjust_config.speed_rate,
pitch=adjust_config.pitch,
volume=adjust_config.volume_gain_db,
sr=sr,
sr=sample_rate,
)
return sr, audio_data
return sample_rate, audio_data


class AudioNormalizer(AudioProcessor):
Expand All @@ -68,7 +69,7 @@ def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUD

class ChatTtsTNProcessor(TextProcessor):
def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegment:
segment.text = ChatTtsTN.normalize(segment.text, context.tts_config)
segment.text = ChatTtsTN.normalize(text=segment.text, config=context.tn_config)
return segment


Expand Down
2 changes: 1 addition & 1 deletion modules/core/pipeline/generate/BatchGenerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def generate_batch_stream(self, batch: TTSBatch):
):
for audio, result in zip(batch.segments, results):
sr, data = result
if len(data) == 0:
if data.size == 0:
audio.done = True
continue
audio.data = np.concatenate([audio.data, data], axis=0)
Expand Down
24 changes: 17 additions & 7 deletions modules/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def set_model(self, model):
def create_synth(self):
chunker = TTSChunker(context=self.context)
segments = chunker.segments()
# Ideally this would be called before the chunker, but it has side effects, so it is applied afterwards
segments = [self.process_text(seg) for seg in segments]

synth = BatchSynth(
input_segments=segments, context=self.context, model=self.model
)
Expand All @@ -41,20 +44,27 @@ def generate(self) -> NP_AUDIO:
synth = self.create_synth()
synth.start_generate()
synth.wait_done()
return synth.sr(), synth.read()
audio = synth.sr(), synth.read()
return self.process_np_audio(audio)

def generate_stream(self) -> Generator[NP_AUDIO, None, None]:
synth = self.create_synth()
synth.start_generate()
while not synth.is_done():
data = synth.read()
if data.size > 0:
yield synth.sr(), data
audio = synth.sr(), data
yield self.process_np_audio(audio)
# TODO: replace with threading.Event
sleep(0.1)
data = synth.read()
if data.size > 0:
yield synth.sr(), data
audio = synth.sr(), data
yield self.process_np_audio(audio)

def process_np_audio(self, audio: NP_AUDIO) -> NP_AUDIO:
    """Post-process generated audio and return it in ndarray form.

    The input is passed through ``self.process_audio`` and the result is
    then handed to ``self.ensure_audio_type`` with the output type
    ``"ndarray"``.

    Args:
        audio: The audio to process; callers in this class pass a
            ``(synth.sr(), data)`` tuple.

    Returns:
        Whatever ``self.ensure_audio_type(..., "ndarray")`` produces for
        the processed audio.
    """
    audio = self.process_audio(audio)
    # process_audio is typed to accept AUDIO, so normalise the result back
    # to the ndarray representation before returning it.
    return self.ensure_audio_type(audio, "ndarray")

def ensure_audio_type(
self, audio: AUDIO, output_type: Literal["ndarray", "segment"]
Expand All @@ -80,14 +90,14 @@ def _to_ndarray(self, audio: AUDIO) -> NP_AUDIO:
return sr, audio
return audio

def process_text(self, text: TTSSegment):
def process_text(self, seg: TTSSegment):
for module in self.modules:
if isinstance(module, TextProcessor):
text = module.process(text)
return text
seg = module.process(segment=seg, context=self.context)
return seg

def process_audio(self, audio: AUDIO):
for module in self.modules:
if isinstance(module, AudioProcessor):
audio = module.process(audio)
audio = module.process(audio=audio, context=self.context)
return audio
5 changes: 3 additions & 2 deletions modules/core/pipeline/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ def process(self, segment: TTSSegment, context: TTSPipelineContext) -> TTSSegmen

class AudioProcessor:
def process(self, audio: AUDIO, context: TTSPipelineContext) -> AUDIO:
if isinstance(audio, np.ndarray):
if isinstance(audio, tuple):
return self._process_array(audio, context)
elif isinstance(audio, AudioSegment):
return self._process_segment(audio, context)
else:
raise ValueError("Unsupported audio type")

def _process_array(self, audio: NP_AUDIO, context: TTSPipelineContext) -> NP_AUDIO:
segment = audio_utils.ndarray_to_segment(audio)
sr, data = audio
segment = audio_utils.ndarray_to_segment(ndarray=data, frame_rate=sr)
processed_segment = self._process_segment(segment, context)
return audio_utils.audiosegment_to_librosawav(processed_segment)

Expand Down
Loading

0 comments on commit 3095d4c

Please sign in to comment.