Skip to content

Commit

Permalink
🐛 fix generate function precision #120
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jul 30, 2024
1 parent e46379c commit 77e6eeb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion modules/repos_static/ChatTTS/ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def generate(
del x

# logits = logits[:, -1].float()
logits = logits.narrow(1, -1, 1).squeeze_(1).to(dtype=self.gpt.dtype)
logits = logits.narrow(1, -1, 1).squeeze_(1).float()

if not infer_text:
# logits = rearrange(logits, "b c n -> (b n) c")
Expand Down

0 comments on commit 77e6eeb

Please sign in to comment.