Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(gui): fix devices list and fix tqdm error in gui #99

Merged
merged 3 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 61 additions & 12 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import sounddevice as sd
import torch
from pebble import ProcessFuture, ProcessPool
from tqdm.tk import tqdm_tk

from .__main__ import init_logger
from .utils import ensure_hubert_model

GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute()
Expand Down Expand Up @@ -54,17 +56,53 @@ def delete_preset(name: str) -> dict:
return load_presets()


def get_devices(
    update: bool = True,
) -> tuple[list[str], list[str], list[int], list[int]]:
    """List the audio devices reported by ``sounddevice``.

    Args:
        update: When true, ``sounddevice`` is terminated and re-initialised
            before querying (presumably so that devices plugged in or removed
            since start-up are picked up -- confirm against sounddevice docs).

    Returns:
        A 4-tuple ``(input_names, output_names, input_indices,
        output_indices)``. Names are formatted ``"<device name> (<host API
        name>)"``. ``input_names[i]`` and ``input_indices[i]`` describe the
        same device, and likewise for the output lists, so a position in a
        name list can be mapped back to the device's ``"index"`` field.
    """
    if update:
        sd._terminate()
        sd._initialize()
    devices = sd.query_devices()

    # Attach the owning host API's name to every device so that devices that
    # share a name across host APIs can be told apart in the label.
    for hostapi in sd.query_hostapis():
        for device_idx in hostapi["devices"]:
            devices[device_idx]["hostapi_name"] = hostapi["name"]

    input_devices: list[str] = []
    output_devices: list[str] = []
    input_devices_indices: list[int] = []
    output_devices_indices: list[int] = []
    # Build names and indices in a single pass so the parallel lists are
    # guaranteed to stay aligned.
    for d in devices:
        has_input = d["max_input_channels"] > 0
        has_output = d["max_output_channels"] > 0
        if not (has_input or has_output):
            continue
        label = f"{d['name']} ({d['hostapi_name']})"
        if has_input:
            input_devices.append(label)
            input_devices_indices.append(d["index"])
        if has_output:
            output_devices.append(label)
            output_devices_indices.append(d["index"])
    return input_devices, output_devices, input_devices_indices, output_devices_indices


def main():
try:
ensure_hubert_model(tqdm_cls=tqdm_tk)
except Exception as e:
LOG.exception(e)
LOG.info("Trying tqdm.std...")
try:
ensure_hubert_model()
except Exception as e:
LOG.exception(e)
try:
ensure_hubert_model(disable=True)
except Exception as e:
LOG.exception(e)
LOG.error(
"Failed to download Hubert model. Please download it manually."
)
return

sg.theme("Dark")
model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))

Expand Down Expand Up @@ -292,7 +330,7 @@ def main():
sg.Combo(
key="input_device",
values=[],
size=(20, 1),
size=(60, 1),
),
],
[
Expand All @@ -301,7 +339,7 @@ def main():
sg.Combo(
key="output_device",
values=[],
size=(20, 1),
size=(60, 1),
),
],
[
Expand All @@ -310,6 +348,8 @@ def main():
key="passthrough_original",
default=False,
),
sg.Push(),
sg.Button("Refresh devices", key="refresh_devices"),
],
[
sg.Frame(
Expand Down Expand Up @@ -403,9 +443,10 @@ def main():
layout = [[column1, column2]]
# layout = [[sg.Column(layout, vertical_alignment="top", scrollable=True, expand_x=True, expand_y=True)]]
window = sg.Window(
f"{__name__.split('.')[0]}", layout, grab_anywhere=True
f"{__name__.split('.')[0]}", layout, grab_anywhere=True, finalize=True
) # , use_custom_titlebar=True)

# for n in ["input_device", "output_device"]:
# window[n].Widget.configure(justify="right")
event, values = window.read(timeout=0.01)

def update_speaker() -> None:
Expand All @@ -420,7 +461,7 @@ def update_speaker() -> None:
)

def update_devices() -> None:
input_devices, output_devices = get_devices()
input_devices, output_devices, _, _ = get_devices()
window["input_device"].update(
values=input_devices, value=values["input_device"]
)
Expand Down Expand Up @@ -465,7 +506,6 @@ def apply_preset(name: str) -> None:
break
if not event == sg.EVENT_TIMEOUT:
LOG.info(f"Event {event}, values {values}")
update_devices()
if event.endswith("_path"):
for name in window.AllKeysDict:
if str(name).endswith("_browse"):
Expand Down Expand Up @@ -493,6 +533,8 @@ def apply_preset(name: str) -> None:
elif event == "presets":
apply_preset(values["presets"])
update_speaker()
elif event == "refresh_devices":
update_devices()
elif event == "config_path":
update_speaker()
elif event == "infer":
Expand Down Expand Up @@ -541,6 +583,9 @@ def apply_preset(name: str) -> None:
if Path(values["input_path"]).exists():
pool.schedule(play_audio, args=[Path(values["input_path"])])
elif event == "start_vc":
_, _, input_device_indices, output_device_indices = get_devices(
update=False
)
from .inference_main import realtime

if future:
Expand Down Expand Up @@ -573,8 +618,12 @@ def apply_preset(name: str) -> None:
version=int(values["realtime_algorithm"][0]),
device="cuda" if values["use_gpu"] else "cpu",
block_seconds=values["block_seconds"],
input_device=values["input_device"],
output_device=values["output_device"],
input_device=input_device_indices[
window["input_device"].widget.current()
],
output_device=output_device_indices[
window["output_device"].widget.current()
],
passthrough_original=values["passthrough_original"],
),
)
Expand Down
20 changes: 14 additions & 6 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,13 @@ def f0_to_coarse(f0: torch.Tensor | float):
return f0_coarse


def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **kwargs):
def download_file(
url: str,
filepath: Path | str,
chunk_size: int = 4 * 1024,
tqdm_cls: type = tqdm,
**kwargs,
):
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
temppath = filepath.parent / f"{filepath.name}.download"
Expand All @@ -292,7 +298,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **
temppath.unlink(missing_ok=True)
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with temppath.open("wb") as f, tqdm(
with temppath.open("wb") as f, tqdm_cls(
total=total,
unit="iB",
unit_scale=True,
Expand All @@ -305,7 +311,7 @@ def download_file(url: str, filepath: Path | str, chunk_size: int = 4 * 1024, **
temppath.rename(filepath)


def ensure_pretrained_model(folder_path: Path) -> None:
def ensure_pretrained_model(folder_path: Path, **kwargs) -> None:
model_urls = [
# "https://huggingface.co/innnky/sovits_pretrained/resolve/main/sovits4/G_0.pth",
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
Expand All @@ -315,17 +321,19 @@ def ensure_pretrained_model(folder_path: Path) -> None:
for model_url in model_urls:
model_path = folder_path / model_url.split("/")[-1]
if not model_path.exists():
download_file(model_url, model_path, desc=f"Downloading {model_path.name}")
download_file(
model_url, model_path, desc=f"Downloading {model_path.name}", **kwargs
)


def ensure_hubert_model(**kwargs) -> Path:
    """Make sure the Hubert checkpoint exists locally, downloading it if needed.

    Args:
        **kwargs: Forwarded unchanged to ``download_file`` (for example
            ``tqdm_cls`` to choose the progress-bar class). A caller-supplied
            ``desc`` overrides the default progress description.

    Returns:
        Path to ``checkpoint_best_legacy_500.pt`` (relative to the current
        working directory).
    """
    vec_path = Path("checkpoint_best_legacy_500.pt")
    vec_path.parent.mkdir(parents=True, exist_ok=True)
    if not vec_path.exists():
        # url = "http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt"
        # url = "https://huggingface.co/innnky/contentvec/resolve/main/checkpoint_best_legacy_500.pt"
        url = "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt"
        # setdefault avoids "got multiple values for keyword argument 'desc'"
        # when the caller passes its own description.
        kwargs.setdefault("desc", "Downloading Hubert model")
        download_file(url, vec_path, **kwargs)
    return vec_path


Expand Down