Skip to content

Commit

Permalink
Optimize tqdm display
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 6, 2024
1 parent 0e278ab commit 516eca6
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions modules/ChatTTS/ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def generate(
attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask

for i in tqdm(range(max_new_token)):
if finish.all():
continue

model_input = self.prepare_inputs_for_generation(inputs_ids,
outputs.past_key_values if i!=0 else None,
Expand Down Expand Up @@ -250,9 +252,6 @@ def generate(

end_idx = end_idx + (~finish).int()

if finish.all():
break

inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids

Expand Down

0 comments on commit 516eca6

Please sign in to comment.