Skip to content

Commit

Permalink
bugfixes, seed for styletts
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Dec 14, 2024
1 parent c1f6dc0 commit 3684c12
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Let me know if you need any adjustments or additional languages!

## Updates

Latest Version: v0.4.20
Latest Version: v0.4.21

Introducing StyleTTS2 engine:

Expand Down
35 changes: 26 additions & 9 deletions RealtimeTTS/engines/style_engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .base_engine import BaseEngine
from queue import Queue
import numpy as np
import random
import torch
import sys
import os
import gc
import time
from numba import cuda

class StyleTTSVoice:
def __init__(self,
Expand Down Expand Up @@ -56,7 +55,9 @@ def __init__(self,
beta: float = 0.7,
diffusion_steps: int = 5,
embedding_scale: float = 1.0,
cuda_reset_delay: float = 0.0): # Delay after resetting CUDA device
cuda_reset_delay: float = 0.0,
seed: int = -1,
):
"""
Initializes the StyleTTS engine with customizable parameters.
Expand Down Expand Up @@ -125,6 +126,8 @@ def __init__(self,
self.diffusion_steps = diffusion_steps
self.embedding_scale = embedding_scale
self.cuda_reset_delay = cuda_reset_delay # Store the delay parameter
self.seed = seed
self.set_seeds(self.seed)

# Add the root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), self.style_root)))
Expand All @@ -134,6 +137,18 @@ def __init__(self,
self.compute_reference_style(self.ref_audio_path)
self.post_init()

def set_seeds(self, seed = 0):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Also disables cuDNN benchmarking and enables cuDNN deterministic mode
    so repeated runs with the same seed take the same code paths.

    Args:
        seed (int): Seed to apply. Pass -1 to draw a random seed in
            [0, 2**32 - 1]. Any other value must lie in that range,
            because np.random.seed rejects values outside it.

    Returns:
        int: The seed that was actually applied (useful when -1 was
            passed and the run should be reproducible later).

    Raises:
        ValueError: If seed is neither -1 nor within [0, 2**32 - 1].
    """
    max_seed = 2**32 - 1
    if seed == -1:
        seed_value = random.randint(0, max_seed)
    else:
        seed_value = seed

    # Validate before touching any generator: np.random.seed would reject
    # an out-of-range value only after torch and random were already
    # re-seeded, leaving the generators in a half-updated state.
    if not 0 <= seed_value <= max_seed:
        raise ValueError(
            f"seed must be -1 or between 0 and {max_seed}, got {seed!r}"
        )

    torch.manual_seed(seed_value)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed_value)
    np.random.seed(seed_value)
    print(f"Seed set to {seed_value}")
    return seed_value

def post_init(self):
self.engine_name = "styletts"

Expand Down Expand Up @@ -250,10 +265,6 @@ def load_model(self):
nltk.download('punkt', quiet=True)

self.textcleaner = TextCleaner()
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)

# Load model config
print('Loading model config from %s' % self.model_config_path)
Expand Down Expand Up @@ -284,6 +295,7 @@ def load_model(self):
_ = [self.model[key].to(self.device) for key in self.model]

# Load model checkpoint
print('Loading model checkpoint from %s' % self.model_checkpoint_path)
params_whole = torch.load(self.model_checkpoint_path, map_location='cpu')
params = params_whole['net']
for key in self.model:
Expand Down Expand Up @@ -381,10 +393,15 @@ def inference(self, text: str,
bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

bert_dur_2 = bert_dur
while bert_dur_2.shape[1] < 100:
bert_dur_2 = torch.cat((bert_dur_2, bert_dur), dim=1)
print(f"New Padding length bert_dur_2: {bert_dur_2.shape[1]}")

noise = torch.randn(1, 256).unsqueeze(1).to(self.device)
s_pred = self.sampler(
noise=noise,
embedding=bert_dur,
embedding=bert_dur_2,
embedding_scale=embedding_scale,
features=self.ref_s,
num_steps=diffusion_steps
Expand Down Expand Up @@ -451,4 +468,4 @@ def set_voice(self, voice: StyleTTSVoice):
model_config_path=voice.model_config_path,
model_checkpoint_path=voice.model_checkpoint_path,
ref_audio_path=voice.ref_audio_path,
)
)
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
# stream2sentence is the core of RealtimeTTS - it quickly converts streamed text into sentences for real-time synthesis
stream2sentence==0.2.9
stream2sentence==0.3.0

# azure-cognitiveservices-speech is for AzureEngine
azure-cognitiveservices-speech==1.41.1

# coqui_tts is for CoquiEngine
coqui_tts==0.25.0
coqui_tts==0.25.1

# elevenlabs is for ElevenlabsEngine
elevenlabs==1.13.3
elevenlabs==1.50.1

# gtts is for GTTSEngine
gtts==2.5.4

# openai is for OpenAIEngine
openai==1.57.0
openai==1.57.4

# pyttsx3 is for SystemEngine
pyttsx3==2.98

# edge-tts is for EdgeEngine
edge-tts==6.1.19
edge-tts==7.0.0

# pyaudio is for playing chunks over output device
pyaudio==0.2.14
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
current_version = "0.4.20"
current_version = "0.4.21"

import setuptools

Expand Down

0 comments on commit 3684c12

Please sign in to comment.