feat: support stream mode (#360)
* Update core.py

* Update core.py

* Update api.py

* gpt support streaming
Ox0400 authored Jun 19, 2024
1 parent a63e9c2 · commit f0babd0
Showing 3 changed files with 61 additions and 20 deletions.
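Taken together, the three files thread a new `stream` flag from `Chat.infer` in `ChatTTS/core.py` through `infer_code` in `ChatTTS/infer/api.py` down to `generate` in `ChatTTS/model/gpt.py`, so audio can be decoded and yielded chunk by chunk while the GPT is still sampling. A minimal sketch of how a caller might consume the new mode, based only on the `infer` signature in this diff (the `load_models` call is an assumption taken from the repository README, not part of this commit):

```python
import ChatTTS

chat = ChatTTS.Chat()
chat.load_models()  # assumed model loader; not part of this diff

texts = ['Hello, this is a streaming test.']

# stream=False (default): infer() drains its internal generator and
# returns the complete waveforms in one shot, as before.
wavs = chat.infer(texts)

# stream=True: infer() returns the generator itself; each item is a list
# of numpy waveform chunks that can be written or played incrementally.
for wav_chunk in chat.infer(texts, stream=True):
    print(f'got a chunk of {len(wav_chunk[0][0])} samples')
```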
ChatTTS/core.py: 38 additions & 13 deletions
@@ -145,7 +145,7 @@ def _load(

        self.check_model()

-    def infer(
+    def _infer(
        self,
        text,
        skip_refine_text=False,
@@ -155,22 +155,21 @@ def infer(
        use_decoder=True,
        do_text_normalization=True,
        lang=None,
+        stream=False,
        do_homophone_replacement=True
    ):

        assert self.check_model(use_decoder=use_decoder)

        if not isinstance(text, list):
            text = [text]

        if do_text_normalization:
            for i, t in enumerate(text):
                _lang = detect_language(t) if lang is None else lang
                if self.init_normalizer(_lang):
                    text[i] = self.normalizer[_lang](t)
                    if _lang == 'zh':
                        text[i] = apply_half2full_map(text[i])

        for i, t in enumerate(text):
            invalid_characters = count_invalid_characters(t)
            if len(invalid_characters):
@@ -190,18 +189,44 @@ def infer(

        text = [params_infer_code.get('prompt', '') + i for i in text]
        params_infer_code.pop('prompt', '')
-        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
-
+        result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
        if use_decoder:
-            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
+            field = 'hiddens'
+            decoder_name = 'decoder'
        else:
-            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
-
-        wav = [self.pretrain_models['vocos'].decode(
-            i.cpu() if torch.backends.mps.is_available() else i
-        ).cpu().numpy() for i in mel_spec]
-
-        return wav
+            field = 'ids'
+            decoder_name = 'dvae'
+        vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+                i.cpu() if torch.backends.mps.is_available() else i
+            ).cpu().numpy() for i in spec]
+        if stream:
+
+            length = 0
+            for result in result_gen:
+                chunk_data = result[field][0]
+                assert len(result[field]) == 1
+                start_seek = length
+                length = len(chunk_data)
+                self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
+                chunk_data = chunk_data[start_seek:]
+                if not len(chunk_data):
+                    continue
+                self.logger.debug(f'new hidden {len(chunk_data)=}')
+                mel_spec = [self.pretrain_models[decoder_name](i[None].permute(0,2,1)) for i in [chunk_data]]
+                wav = vocos_decode(mel_spec)
+                self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
+                yield wav
+            return
+        mel_spec = [self.pretrain_models[decoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
+        yield vocos_decode(mel_spec)
+
+    def infer(self, *args, **kwargs):
+        stream = kwargs.setdefault('stream', False)
+        res_gen = self._infer(*args, **kwargs)
+        if stream:
+            return res_gen
+        else:
+            return next(res_gen)

    def sample_random_speaker(self, ):

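The heart of this change is a generator-wrapper pattern: `_infer` is now always a generator, and the public `infer` either hands that generator straight back to streaming callers or drains the single eager result with `next()`. Because `infer_code` yields cumulative output, `_infer` also keeps `length`/`start_seek` bookkeeping so each iteration decodes only the newly appended tokens. A standalone sketch of the wrapper pattern, with `_work`/`work` as illustrative stand-ins for `_infer`/`infer`:

```python
# Minimal sketch: the worker is always a generator; the public entry point
# returns it as-is for streaming callers, or drains the one eager item.
def _work(n, stream=False):
    if stream:
        for i in range(n):
            yield i          # emit partial results as they are produced
        return
    yield list(range(n))     # non-stream path: a single complete result

def work(n, stream=False):
    gen = _work(n, stream=stream)
    return gen if stream else next(gen)

assert work(3) == [0, 1, 2]                      # eager call: full result
assert list(work(3, stream=True)) == [0, 1, 2]   # streaming call: iterate
```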
ChatTTS/infer/api.py: 3 additions & 1 deletion
@@ -13,6 +13,7 @@ def infer_code(
    temperature = 0.3,
    repetition_penalty = 1.05,
    max_new_token = 2048,
+    stream=False,
    **kwargs
):

@@ -66,6 +67,7 @@ def infer_code(
        eos_token = num_code,
        max_new_token = max_new_token,
        infer_text = False,
+        stream = stream,
        **kwargs
    )

@@ -122,4 +124,4 @@ def refine_text(
        infer_text = True,
        **kwargs
    )
-    return result
+    return result
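A subtlety behind this plumbing: once `generate` in `gpt.py` contains `yield` (see the next file), every call to it returns a generator object, even with `stream=False`, which is why `core.py` above switched from `result = infer_code(...)` to `result_gen = infer_code(...)` plus `next(result_gen)`. A toy illustration of that language behavior, with illustrative names:

```python
# A function containing `yield` returns a generator when called, even on
# code paths where no mid-stream yield ever fires.
def generate(stream=False):
    for step in range(3):
        if stream:
            yield f'partial {step}'
    yield 'final'

assert next(generate()) == 'final'            # stream=False: drain one item
assert len(list(generate(stream=True))) == 4  # stream=True: 3 chunks + final
```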
ChatTTS/model/gpt.py: 20 additions & 6 deletions
@@ -167,6 +167,7 @@ def generate(
        infer_text=False,
        return_attn=False,
        return_hidden=False,
+        stream=False,
    ):

        with torch.no_grad():
@@ -264,7 +265,20 @@ def generate(
                del idx_next

                end_idx += (~finish).int().to(end_idx.device)
-
+                if stream:
+                    if end_idx % 24 and not finish.all():
+                        continue
+                    y_inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
+                    y_inputs_ids = [i[:, 0] for i in y_inputs_ids] if infer_text else y_inputs_ids
+                    y_hiddens = [[]]
+                    if return_hidden:
+                        y_hiddens = torch.stack(hiddens, 1)
+                        y_hiddens = [y_hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
+                    yield {
+                        'ids': y_inputs_ids,
+                        'attentions': attentions,
+                        'hiddens':y_hiddens,
+                    }
                if finish.all():
                    pbar.update(max_new_token-i-1)
                    break
@@ -277,12 +291,12 @@ def generate(
                hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]

            if not finish.all():
-                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
+                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')

-            del finish
+            del finish

-            return {
+            yield {
                'ids': inputs_ids,
                'attentions': attentions,
                'hiddens':hiddens,
-            }
+            }
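On the producer side, `generate` now yields its cumulative `ids`/`hiddens` every 24 decoding steps and once more when sampling finishes, leaving it to the consumer to slice off what it has already seen; that is the `start_seek`/`length` bookkeeping in `_infer` above. A minimal sketch of this cumulative-chunk protocol, with illustrative names and sizes:

```python
# Producer: yield the *cumulative* token list every `chunk_every` steps and
# once at the end, mirroring `if end_idx % 24 and not finish.all(): continue`.
def produce(total=50, chunk_every=24):
    tokens = []
    for i in range(total):
        tokens.append(i)
        if (i + 1) % chunk_every == 0 or i == total - 1:
            yield list(tokens)

# Consumer: keep a read offset so each token is decoded exactly once.
length = 0
for cumulative in produce():
    new = cumulative[length:]   # only the newly generated tokens
    length = len(cumulative)
    if new:
        print(f'decoding {len(new)} new tokens')
```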
