Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

support stream mode #360

Merged
merged 6 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _load(

self.check_model()

def infer(
def _infer(
self,
text,
skip_refine_text=False,
Expand All @@ -136,27 +136,26 @@ def infer(
use_decoder=True,
do_text_normalization=True,
lang=None,
stream=False,
):

assert self.check_model(use_decoder=use_decoder)

if not isinstance(text, list):
text = [text]

if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
if self.init_normalizer(_lang):
text[i] = self.normalizer[_lang](t)
if _lang == 'zh':
text[i] = apply_half2full_map(text[i])

for i, t in enumerate(text):
invalid_characters = count_invalid_characters(t)
if len(invalid_characters):
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
text[i] = apply_character_map(t)

if not skip_refine_text:
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
Expand All @@ -166,18 +165,44 @@ def infer(

text = [params_infer_code.get('prompt', '') + i for i in text]
params_infer_code.pop('prompt', '')
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)

result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
if use_decoder:
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
field = 'hiddens'
docoder_name = 'decoder'
else:
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]

wav = [self.pretrain_models['vocos'].decode(
i.cpu() if torch.backends.mps.is_available() else i
).cpu().numpy() for i in mel_spec]

return wav
field = 'ids'
docoder_name = 'dvae'
vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
i.cpu() if torch.backends.mps.is_available() else i
).cpu().numpy() for i in spec]
if stream:

length = 0
for result in result_gen:
chunk_data = result[field][0]
assert len(result[field]) == 1
start_seek = length
length = len(chunk_data)
self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
chunk_data = chunk_data[start_seek:]
if not len(chunk_data):
continue
self.logger.debug(f'new hidden {len(chunk_data)=}')
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in [chunk_data]]
wav = vocos_decode(mel_spec)
self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
yield wav
return
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
yield vocos_decode(mel_spec)

def infer(self, *args, **kwargs):
    """Run inference, returning either a stream or a single result.

    All positional and keyword arguments are forwarded to ``self._infer``.
    When ``stream=True`` is passed, the underlying generator is returned
    unchanged so the caller can iterate over audio chunks as they are
    produced; otherwise the generator is advanced once and its first
    (complete) result is returned directly.
    """
    # Ensure 'stream' is always present in kwargs so _infer sees an
    # explicit value; defaults to non-streaming behavior.
    streaming = kwargs.setdefault('stream', False)
    result_generator = self._infer(*args, **kwargs)
    return result_generator if streaming else next(result_generator)

def sample_random_speaker(self, ):

Expand Down
4 changes: 3 additions & 1 deletion ChatTTS/infer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def infer_code(
temperature = 0.3,
repetition_penalty = 1.05,
max_new_token = 2048,
stream=False,
**kwargs
):

Expand Down Expand Up @@ -66,6 +67,7 @@ def infer_code(
eos_token = num_code,
max_new_token = max_new_token,
infer_text = False,
stream = stream,
**kwargs
)

Expand Down Expand Up @@ -122,4 +124,4 @@ def refine_text(
infer_text = True,
**kwargs
)
return result
return result
26 changes: 20 additions & 6 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def generate(
infer_text=False,
return_attn=False,
return_hidden=False,
stream=False,
):

with torch.no_grad():
Expand Down Expand Up @@ -264,7 +265,20 @@ def generate(
del idx_next

end_idx += (~finish).int().to(end_idx.device)

if stream:
if end_idx % 24 and not finish.all():
continue
y_inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
y_inputs_ids = [i[:, 0] for i in y_inputs_ids] if infer_text else y_inputs_ids
y_hiddens = [[]]
if return_hidden:
y_hiddens = torch.stack(hiddens, 1)
y_hiddens = [y_hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
yield {
'ids': y_inputs_ids,
'attentions': attentions,
'hiddens':y_hiddens,
}
if finish.all():
pbar.update(max_new_token-i-1)
break
Expand All @@ -277,12 +291,12 @@ def generate(
hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]

if not finish.all():
self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')

del finish

del finish

return {
yield {
'ids': inputs_ids,
'attentions': attentions,
'hiddens':hiddens,
}
}