From a41ce6320ba74cd9f289706eee2490b7c988df3a Mon Sep 17 00:00:00 2001 From: zhipeng <5310853+Ox0400@users.noreply.github.com> Date: Sat, 22 Jun 2024 14:16:04 +0800 Subject: [PATCH] feat(cmd): add multiple texts (#366) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> --- examples/cmd/run.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/cmd/run.py b/examples/cmd/run.py index b4a387f24..2294bb614 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -6,10 +6,12 @@ now_dir = os.getcwd() sys.path.append(now_dir) +import wave +import argparse + from dotenv import load_dotenv load_dotenv("sha256.env") -import wave import ChatTTS from tools.audio import unsafe_float_to_int16 @@ -26,10 +28,8 @@ def save_wav_file(wav, index): wf.writeframes(unsafe_float_to_int16(wav)) logger.info(f"Audio saved to {wav_filename}") -def main(): - # Retrieve text from command line argument - text_input = sys.argv[1] if len(sys.argv) > 1 else "" - logger.info("Text input: %s", text_input) +def main(texts: list[str]): + logger.info("Text input: %s", str(texts)) chat = ChatTTS.Chat(get_logger("ChatTTS")) logger.info("Initializing ChatTTS...") @@ -39,7 +39,7 @@ def main(): logger.error("Models load failed.") sys.exit(1) - wavs = chat.infer((text_input), use_decoder=True) + wavs = chat.infer(texts, use_decoder=True) logger.info("Inference completed. Audio generation successful.") # Save each generated wav file to a local file for index, wav in enumerate(wavs): @@ -47,5 +47,8 @@ def main(): if __name__ == "__main__": logger.info("Starting the TTS application...") - main() + parser = argparse.ArgumentParser(description='ChatTTS Command', usage="--stream hello, my name is bob.") + parser.add_argument("text", help="Original text", default='YOUR TEXT HERE', nargs='*') + args = parser.parse_args() + main(args.text) logger.info("TTS application finished.")