Skip to content

Commit

Permalink
🐛 fix load speaker from seed #69
Browse files Browse the repository at this point in the history
- 增加 Speaker.from_seed
- 修复 webui from seed error
  • Loading branch information
zhzLuke96 committed Jun 24, 2024
1 parent 8dd6925 commit 304c318
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
6 changes: 6 additions & 0 deletions modules/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def from_tensor(tensor):
speaker.emb = tensor
return speaker

@staticmethod
def from_seed(seed: int):
    """Create a ``Speaker`` from an integer seed.

    The instance is constructed with ``seed_or_tensor=seed``; its ``emb``
    attribute is then assigned the value returned by
    ``create_speaker_from_seed(seed)``.

    Args:
        seed: Integer seed passed both to the constructor and to
            ``create_speaker_from_seed``.

    Returns:
        The newly created ``Speaker`` with ``emb`` populated.
    """
    # Construct first, then compute the embedding — same order as before.
    spk = Speaker(seed_or_tensor=seed)
    embedding = create_speaker_from_seed(seed)
    spk.emb = embedding
    return spk

def __init__(
self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
):
Expand Down
12 changes: 6 additions & 6 deletions modules/webui/speaker/speaker_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ def create_spk_from_seed(
gender: str,
desc: str,
):
chat_tts = load_chat_tts()
with SeedContext(seed, True):
emb = chat_tts.sample_random_speaker()
spk = Speaker(seed_or_tensor=-2, name=name, gender=gender, describe=desc)
spk.emb = emb
spk = Speaker.from_seed(seed)
spk.name = name
spk.gender = gender
spk.describe = desc

with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
torch.save(spk, tmp_file)
Expand All @@ -82,7 +81,8 @@ def test_spk_voice(
text: str,
progress=gr.Progress(track_tqdm=True),
):
return tts_generate(spk=seed, text=text, progress=progress)
spk = Speaker.from_seed(seed)
return tts_generate(spk=spk, text=text, progress=progress)


def random_speaker():
Expand Down
3 changes: 3 additions & 0 deletions modules/webui/webui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def tts_generate(
infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
infer_seed = int(infer_seed)

if isinstance(spk, int):
spk = Speaker.from_seed(spk)

if spk_file:
try:
spk: Speaker = Speaker.from_file(spk_file)
Expand Down

0 comments on commit 304c318

Please sign in to comment.