diff --git a/modules/speaker.py b/modules/speaker.py index a0dc770..ec20816 100644 --- a/modules/speaker.py +++ b/modules/speaker.py @@ -29,6 +29,12 @@ def from_tensor(tensor): speaker.emb = tensor return speaker + @staticmethod + def from_seed(seed: int): + speaker = Speaker(seed_or_tensor=seed) + speaker.emb = create_speaker_from_seed(seed) + return speaker + def __init__( self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe="" ): diff --git a/modules/webui/speaker/speaker_creator.py b/modules/webui/speaker/speaker_creator.py index 3ad069e..82810e2 100644 --- a/modules/webui/speaker/speaker_creator.py +++ b/modules/webui/speaker/speaker_creator.py @@ -62,11 +62,10 @@ def create_spk_from_seed( gender: str, desc: str, ): - chat_tts = load_chat_tts() - with SeedContext(seed, True): - emb = chat_tts.sample_random_speaker() - spk = Speaker(seed_or_tensor=-2, name=name, gender=gender, describe=desc) - spk.emb = emb + spk = Speaker.from_seed(seed) + spk.name = name + spk.gender = gender + spk.describe = desc with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file: torch.save(spk, tmp_file) @@ -82,7 +81,8 @@ def test_spk_voice( text: str, progress=gr.Progress(track_tqdm=True), ): - return tts_generate(spk=seed, text=text, progress=progress) + spk = Speaker.from_seed(seed) + return tts_generate(spk=spk, text=text, progress=progress) def random_speaker(): diff --git a/modules/webui/webui_utils.py b/modules/webui/webui_utils.py index 3ef3ac1..15056b0 100644 --- a/modules/webui/webui_utils.py +++ b/modules/webui/webui_utils.py @@ -208,6 +208,9 @@ def tts_generate( infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64) infer_seed = int(infer_seed) + if isinstance(spk, int): + spk = Speaker.from_seed(spk) + if spk_file: try: spk: Speaker = Speaker.from_file(spk_file)