Skip to content

Commit

Permalink
fix(gui): fix devices list and fix tqdm error in gui (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored Mar 24, 2023
1 parent 495b7cb commit 59724cd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
73 changes: 61 additions & 12 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import sounddevice as sd
import torch
from pebble import ProcessFuture, ProcessPool
from tqdm.tk import tqdm_tk

from .__main__ import init_logger
from .utils import ensure_hubert_model

GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute()
Expand Down Expand Up @@ -54,17 +56,53 @@ def delete_preset(name: str) -> dict:
return load_presets()


def get_devices(
    update: bool = True,
) -> tuple[list[str], list[str], list[int], list[int]]:
    """List the audio input and output devices known to sounddevice.

    Args:
        update: When true, terminate and re-initialise sounddevice before
            querying (uses the library's private ``_terminate`` /
            ``_initialize`` hooks; intended to pick up device changes).

    Returns:
        A 4-tuple ``(input_names, output_names, input_indices,
        output_indices)``. Each name is formatted as
        ``"<device name> (<host API name>)"``, and ``*_indices[i]`` is the
        device index belonging to ``*_names[i]``. A device that has both
        input and output channels appears in both pairs of lists.
    """
    if update:
        sd._terminate()
        sd._initialize()
    devices = sd.query_devices()
    hostapis = sd.query_hostapis()

    # Map device position -> host API name without mutating the dicts that
    # sounddevice handed back.
    hostapi_names = {
        device_idx: hostapi["name"]
        for hostapi in hostapis
        for device_idx in hostapi["devices"]
    }

    input_devices: list[str] = []
    output_devices: list[str] = []
    input_devices_indices: list[int] = []
    output_devices_indices: list[int] = []

    # Single pass so that each name list and its index list stay aligned.
    for position, device in enumerate(devices):
        label = f"{device['name']} ({hostapi_names[position]})"
        # Not every sounddevice release exposes an "index" key; the position
        # in the full device list is used as the fallback device index.
        index = device.get("index", position)
        if device["max_input_channels"] > 0:
            input_devices.append(label)
            input_devices_indices.append(index)
        if device["max_output_channels"] > 0:
            output_devices.append(label)
            output_devices_indices.append(index)

    return input_devices, output_devices, input_devices_indices, output_devices_indices


def main():
try:
ensure_hubert_model(tqdm_cls=tqdm_tk)
except Exception as e:
LOG.exception(e)
LOG.info("Trying tqdm.std...")
try:
ensure_hubert_model()
except Exception as e:
LOG.exception(e)
try:
ensure_hubert_model(disable=True)
except Exception as e:
LOG.exception(e)
LOG.error(
"Failed to download Hubert model. Please download it manually."
)
return

sg.theme("Dark")
model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))

Expand Down Expand Up @@ -292,7 +330,7 @@ def main():
sg.Combo(
key="input_device",
values=[],
size=(20, 1),
size=(60, 1),
),
],
[
Expand All @@ -301,7 +339,7 @@ def main():
sg.Combo(
key="output_device",
values=[],
size=(20, 1),
size=(60, 1),
),
],
[
Expand All @@ -310,6 +348,8 @@ def main():
key="passthrough_original",
default=False,
),
sg.Push(),
sg.Button("Refresh devices", key="refresh_devices"),
],
[
sg.Frame(
Expand Down Expand Up @@ -403,9 +443,10 @@ def main():
layout = [[column1, column2]]
# layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]]
window = sg.Window(
f"{__name__.split('.')[0]}", layout, grab_anywhere=True
f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True
) # , use_custom_titlebar=True)

# for n in ["input_device", "output_device"]:
# window[n].Widget.configure(justify="right")
event, values = window.read(timeout=0.01)

def update_speaker() -> None:
Expand All @@ -420,7 +461,7 @@ def update_speaker() -> None:
)

def update_devices() -> None:
input_devices, output_devices = get_devices()
input_devices, output_devices, _, _ = get_devices()
window["input_device"].update(
values=input_devices, value=values["input_device"]
)
Expand Down Expand Up @@ -465,7 +506,6 @@ def apply_preset(name: str) -> None:
break
if not event == sg.EVENT_TIMEOUT:
LOG.info(f"Event {event}, values {values}")
update_devices()
if event.endswith("_path"):
for name in window.AllKeysDict:
if str(name).endswith("_browse"):
Expand Down Expand Up @@ -493,6 +533,8 @@ def apply_preset(name: str) -> None:
elif event == "presets":
apply_preset(values["presets"])
update_speaker()
elif event == "refresh_devices":
update_devices()
elif event == "config_path":
update_speaker()
elif event == "infer":
Expand Down Expand Up @@ -541,6 +583,9 @@ def apply_preset(name: str) -> None:
if Path(values["input_path"]).exists():
pool.schedule(play_audio, args=[Path(values["input_path"])])
elif event == "start_vc":
_, _, input_device_indices, output_device_indices = get_devices(
update=False
)
from .inference_main import realtime

if future:
Expand Down Expand Up @@ -573,8 +618,12 @@ def apply_preset(name: str) -> None:
version=int(values["realtime_algorithm"][0]),
device="cuda" if values["use_gpu"] else "cpu",
block_seconds=values["block_seconds"],
input_device=values["input_device"],
output_device=values["output_device"],
input_device=input_device_indices[
window["input_device"].widget.current()
],
output_device=output_device_indices[
window["output_device"].widget.current()
],
passthrough_original=values["passthrough_original"],
),
)
Expand Down
20 changes: 14 additions & 6 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,13 @@ def f0_to_coarse(f0: torch.Tensor | float):
return f0_coarse


def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **kwargs):
def download_file(
url: str,
filepath: Path | str,
chunk_size: int = 4 * 1024,
tqdm_cls: type = tqdm,
**kwargs,
):
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
temppath = filepath.parent / f"{filepath.name}.download"
Expand All @@ -292,7 +298,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **
temppath.unlink(missing_ok=True)
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with temppath.open("wb") as f, tqdm(
with temppath.open("wb") as f, tqdm_cls(
total=total,
unit="iB",
unit_scale=True,
Expand All @@ -305,7 +311,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **
temppath.rename(filepath)


def ensure_pretrained_model(folder_path: Path) -> None:
def ensure_pretrained_model(folder_path: Path, **kwargs) -> None:
model_urls = [
# "https://huggingface.co/innnky/sovits_pretrained/resolve/main/sovits4/G_0.pth",
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
Expand All @@ -315,17 +321,19 @@ def ensure_pretrained_model(folder_path: Path) -> None:
for model_url in model_urls:
model_path = folder_path / model_url.split("/")[-1]
if not model_path.exists():
download_file(model_url, model_path, desc=f"Downloading {model_path.name}")
download_file(
model_url, model_path, desc=f"Downloading {model_path.name}", **kwargs
)


def ensure_hubert_model(**kwargs) -> Path:
    """Return the path of the Hubert checkpoint, downloading it if missing.

    The checkpoint is looked up as ``checkpoint_best_legacy_500.pt`` relative
    to the current working directory. Any keyword arguments are forwarded
    unchanged to ``download_file``.
    """
    # Alternative mirrors, kept for reference:
    # "http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt"
    # "https://huggingface.co/innnky/contentvec/resolve/main/checkpoint_best_legacy_500.pt"
    checkpoint = Path("checkpoint_best_legacy_500.pt")
    checkpoint.parent.mkdir(parents=True, exist_ok=True)
    if checkpoint.exists():
        return checkpoint
    download_file(
        "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt",
        checkpoint,
        desc="Downloading Hubert model",
        **kwargs,
    )
    return checkpoint


Expand Down

0 comments on commit 59724cd

Please sign in to comment.