Skip to content

Commit

Permalink
feat(webui): impl. zero shot infer
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 19, 2024
1 parent 72a6f80 commit 6f4ceb9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _load(
logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
gpt.from_pretrained(gpt_ckpt_path)
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
Expand Down
18 changes: 16 additions & 2 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import gradio as gr

from tools.audio import float_to_int16, has_ffmpeg_installed
from tools.audio import float_to_int16, has_ffmpeg_installed, load_audio
from tools.logger import get_logger

logger = get_logger(" WebUI ")
Expand Down Expand Up @@ -114,6 +114,15 @@ def reload_chat(coef: Optional[str]) -> str:
return chat.coef


def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str:
if sample_audio_input is None:
return ""
sample_audio = load_audio(sample_audio_input, 24000)
spk_smp = chat.sample_audio_speaker(sample_audio)
del sample_audio
return spk_smp


def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
return gr.update(
value=generate_button, visible=is_reset, interactive=is_reset
Expand Down Expand Up @@ -142,7 +151,7 @@ def refine_text(


def generate_audio(
text, temperature, top_P, top_K, spk_emb_text: str, stream, audio_seed_input
text, temperature, top_P, top_K, spk_emb_text: str, stream, audio_seed_input, sample_text_input, sample_audio_code_input,
):
global chat, has_interrupted

Expand All @@ -156,6 +165,11 @@ def generate_audio(
top_K=top_K,
)

if sample_text_input and sample_audio_code_input:
params_infer_code.txt_smp = sample_text_input
params_infer_code.spk_smp = sample_audio_code_input
params_infer_code.spk_emb = None

with TorchSeedContext(audio_seed_input):
wav = chat.infer(
text,
Expand Down
51 changes: 44 additions & 7 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,43 @@ def main():
gr.Markdown("- **GitHub Repo**: https://github.com/2noise/ChatTTS")
gr.Markdown("- **HuggingFace Repo**: https://huggingface.co/2Noise/ChatTTS")

text_input = gr.Textbox(
label="Input Text",
lines=4,
placeholder="Please Input Text...",
value=ex[0][0],
interactive=True,
)
with gr.Row():
with gr.Column(scale=2):
text_input = gr.Textbox(
label="Input Text",
lines=4,
max_lines=4,
placeholder="Please Input Text...",
value=ex[0][0],
interactive=True,
)
sample_text_input = gr.Textbox(
label="Sample Text",
lines=4,
max_lines=4,
placeholder="If Sample Audio and Sample Text are available, the Speaker Embedding will be disabled.",
interactive=True,
)
with gr.Column():
with gr.Tab(label="Sample Audio"):
sample_audio_input = gr.Audio(
value=None,
type="filepath",
interactive=True,
show_label=False,
waveform_options=gr.WaveformOptions(
sample_rate=24000,
),
scale=1,
)
with gr.Tab(label="Sample Audio Code"):
sample_audio_code_input = gr.Textbox(
lines=12,
max_lines=12,
show_label=False,
placeholder="Paste the Code copied before after uploading Sample Audio.",
interactive=True,
)

with gr.Row():
refine_text_checkbox = gr.Checkbox(
Expand Down Expand Up @@ -126,6 +156,11 @@ def main():
show_copy_button=True,
)

sample_audio_input.change(
fn=on_upload_sample_audio,
inputs=sample_audio_input, outputs=sample_audio_code_input,
).then(fn=lambda: gr.Info("Sampled Audio Code generated at another Tab."))

# 使用Gradio的回调功能来更新数值输入框
voice_selection.change(
fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input
Expand Down Expand Up @@ -181,6 +216,8 @@ def make_audio(autoplay, stream):
spk_emb_text,
stream_mode_checkbox,
audio_seed_input,
sample_text_input,
sample_audio_code_input,
],
outputs=audio_output,
).then(
Expand Down

0 comments on commit 6f4ceb9

Please sign in to comment.