diff --git a/src/so_vits_svc_fork/gui.py b/src/so_vits_svc_fork/gui.py index 284c09a2..8243b439 100644 --- a/src/so_vits_svc_fork/gui.py +++ b/src/so_vits_svc_fork/gui.py @@ -9,8 +9,10 @@ import sounddevice as sd import torch from pebble import ProcessFuture, ProcessPool +from tqdm.tk import tqdm_tk from .__main__ import init_logger +from .utils import ensure_hubert_model GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json" GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute() @@ -54,17 +56,53 @@ def delete_preset(name: str) -> dict: return load_presets() -def get_devices(update: bool = True) -> tuple[list[str], list[str]]: +def get_devices( + update: bool = True, +) -> tuple[list[str], list[str], list[int], list[int]]: if update: sd._terminate() sd._initialize() devices = sd.query_devices() - input_devices = [d["name"] for d in devices if d["max_input_channels"] > 0] - output_devices = [d["name"] for d in devices if d["max_output_channels"] > 0] - return input_devices, output_devices + hostapis = sd.query_hostapis() + for hostapi in hostapis: + for device_idx in hostapi["devices"]: + devices[device_idx]["hostapi_name"] = hostapi["name"] + input_devices = [ + f"{d['name']} ({d['hostapi_name']})" + for d in devices + if d["max_input_channels"] > 0 + ] + output_devices = [ + f"{d['name']} ({d['hostapi_name']})" + for d in devices + if d["max_output_channels"] > 0 + ] + input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0] + output_devices_indices = [ + d["index"] for d in devices if d["max_output_channels"] > 0 + ] + return input_devices, output_devices, input_devices_indices, output_devices_indices def main(): + try: + ensure_hubert_model(tqdm_cls=tqdm_tk) + except Exception as e: + LOG.exception(e) + LOG.info("Trying tqdm.std...") + try: + ensure_hubert_model() + except Exception as e: + LOG.exception(e) + try: + ensure_hubert_model(disable=True) + except Exception as e: + LOG.exception(e) + LOG.error( + "Failed to download Hubert model. Please download it manually." + ) + return + sg.theme("Dark") model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth"))) @@ -292,7 +330,7 @@ def main(): sg.Combo( key="input_device", values=[], - size=(20, 1), + size=(60, 1), ), ], [ @@ -301,7 +339,7 @@ def main(): sg.Combo( key="output_device", values=[], - size=(20, 1), + size=(60, 1), ), ], [ @@ -310,6 +348,8 @@ def main(): key="passthrough_original", default=False, ), + sg.Push(), + sg.Button("Refresh devices", key="refresh_devices"), ], [ sg.Frame( @@ -403,9 +443,10 @@ def main(): layout = [[column1, column2]] # layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]] window = sg.Window( - f"{__name__.split('.')[0]}", layout, grab_anywhere=True + f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True ) # , use_custom_titlebar=True) - + # for n in ["input_device", "output_device"]: + # window[n].Widget.configure(justify="right") event, values = window.read(timeout=0.01) def update_speaker() -> None: @@ -420,7 +461,7 @@ def update_speaker() -> None: ) def update_devices() -> None: - input_devices, output_devices = get_devices() + input_devices, output_devices, _, _ = get_devices() window["input_device"].update( values=input_devices, value=values["input_device"] ) @@ -465,7 +506,6 @@ def apply_preset(name: str) -> None: break if not event == sg.EVENT_TIMEOUT: LOG.info(f"Event {event}, values {values}") - update_devices() if event.endswith("_path"): for name in window.AllKeysDict: if str(name).endswith("_browse"): @@ -493,6 +533,8 @@ def apply_preset(name: str) -> None: elif event == "presets": apply_preset(values["presets"]) update_speaker() + elif event == "refresh_devices": + update_devices() elif event == "config_path": update_speaker() elif event == "infer": @@ -541,6 +583,9 @@ def apply_preset(name: str) -> None: if Path(values["input_path"]).exists(): pool.schedule(play_audio, args=[Path(values["input_path"])]) elif event == "start_vc": + _, _, input_device_indices, output_device_indices = get_devices( + update=False + ) from .inference_main import realtime if future: @@ -573,8 +618,12 @@ def apply_preset(name: str) -> None: version=int(values["realtime_algorithm"][0]), device="cuda" if values["use_gpu"] else "cpu", block_seconds=values["block_seconds"], - input_device=values["input_device"], - output_device=values["output_device"], + input_device=input_device_indices[ + window["input_device"].widget.current() + ], + output_device=output_device_indices[ + window["output_device"].widget.current() + ], passthrough_original=values["passthrough_original"], ), ) diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index 02e63530..bb1e5c79 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -283,7 +283,13 @@ def f0_to_coarse(f0: torch.Tensor | float): return f0_coarse -def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **kwargs): +def download_file( + url: str, + filepath: Path | str, + chunk_size: int = 4 * 1024, + tqdm_cls: type = tqdm, + **kwargs, +): filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) temppath = filepath.parent / f"{filepath.name}.download" @@ -292,7 +298,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, ** temppath.unlink(missing_ok=True) resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) - with temppath.open("wb") as f, tqdm( + with temppath.open("wb") as f, tqdm_cls( total=total, unit="iB", unit_scale=True, @@ -305,7 +311,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, ** temppath.rename(filepath) -def ensure_pretrained_model(folder_path: Path) -> None: +def ensure_pretrained_model(folder_path: Path, **kwargs) -> None: model_urls = [ # "https://huggingface.co/innnky/sovits_pretrained/resolve/main/sovits4/G_0.pth", "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth", @@ -315,17 +321,19 @@ def ensure_pretrained_model(folder_path: Path) -> None: for model_url in model_urls: model_path = folder_path / model_url.split("/")[-1] if not model_path.exists(): - download_file(model_url, model_path, desc=f"Downloading {model_path.name}") + download_file( + model_url, model_path, desc=f"Downloading {model_path.name}", **kwargs + ) -def ensure_hubert_model() -> Path: +def ensure_hubert_model(**kwargs) -> Path: vec_path = Path("checkpoint_best_legacy_500.pt") vec_path.parent.mkdir(parents=True, exist_ok=True) if not vec_path.exists(): # url = "http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt" # url = "https://huggingface.co/innnky/contentvec/resolve/main/checkpoint_best_legacy_500.pt" url = "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt" - download_file(url, vec_path, desc="Downloading Hubert model") + download_file(url, vec_path, desc="Downloading Hubert model", **kwargs) return vec_path