-
Notifications
You must be signed in to change notification settings - Fork 269
/
run_gradio.py
32 lines (28 loc) · 1.54 KB
/
run_gradio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.interface.gradio import create_ui
import json
import torch
def main(args):
torch.manual_seed(42)
interface = create_ui(
model_config_path = args.model_config,
ckpt_path=args.ckpt_path,
pretrained_name=args.pretrained_name,
pretransform_ckpt_path=args.pretransform_ckpt_path,
model_half=args.model_half
)
interface.queue()
interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run gradio interface')
parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False)
parser.add_argument('--username', type=str, help='Gradio username', required=False)
parser.add_argument('--password', type=str, help='Gradio password', required=False)
parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
args = parser.parse_args()
main(args)