Skip to content

Commit

Permalink
feat: configurable input and output devices
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed Mar 18, 2023
1 parent ac5b9de commit 7e5f366
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/so_vits_svc_fork/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def infer(
)
@click.option("-s", "--speaker", type=str, default=None, help="speaker name")
@click.option("-v", "--version", type=int, default=2, help="version")
@click.option("-i", "--input-device", type=int, default=None, help="input device")
@click.option("-o", "--output-device", type=int, default=None, help="output device")
def vc(
# paths
model_path: Path,
Expand All @@ -306,6 +308,8 @@ def vc(
crossfade_seconds: float,
block_seconds: float,
version: int,
input_device: int | str | None,
output_device: int | str | None,
device: Literal["cpu", "cuda"],
) -> None:
"""Realtime inference from microphone"""
Expand Down Expand Up @@ -343,6 +347,8 @@ def vc(
db_thresh=db_thresh,
pad_seconds=pad_seconds,
version=version,
input_device=input_device,
output_device=output_device,
device=device,
)

Expand Down
27 changes: 27 additions & 0 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ def play_audio(path: Path | str):
def main():
sg.theme("Dark")
model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))

devices = sd.query_devices()
input_devices = [d["name"] for d in devices if d["max_input_channels"] > 0]
output_devices = [d["name"] for d in devices if d["max_output_channels"] > 0]
devices[sd.default.device[0]]["name"]
devices[sd.default.device[1]]["name"]

layout = [
[
sg.Frame(
Expand Down Expand Up @@ -221,6 +228,24 @@ def main():
key="realtime_algorithm",
),
],
[
sg.Text("Input device"),
sg.Combo(
key="input_device",
values=input_devices,
size=(20, 1),
default_value=input_devices[0],
),
],
[
sg.Text("Output device"),
sg.Combo(
key="output_device",
values=output_devices,
size=(20, 1),
default_value=output_devices[0],
),
],
],
)
],
Expand Down Expand Up @@ -332,6 +357,8 @@ def update_combo() -> None:
version=int(values["realtime_algorithm"][0]),
device="cuda" if values["use_gpu"] else "cpu",
block_seconds=values["block_seconds"],
input_device=values["input_device"],
output_device=values["output_device"],
),
)
elif event == "stop_vc":
Expand Down
32 changes: 32 additions & 0 deletions src/so_vits_svc_fork/inference_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def realtime(
crossfade_seconds: float = 0.05,
block_seconds: float = 0.5,
version: int = 2,
input_device: int | str | None = None,
output_device: int | str | None = None,
device: Literal["cpu", "cuda"] = "cuda" if torch.cuda.is_available() else "cpu",
):
import sounddevice as sd
Expand All @@ -111,6 +113,35 @@ def realtime(
svc_model=svc_model,
)

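# Log every available audio device, then resolve the requested input/output
# devices: a name is mapped to the index of the first device with that name,
# and a missing or out-of-range selection falls back to the system default.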
devices = sd.query_devices()
LOG.info(f"Device: {devices}")
if isinstance(input_device, str):
input_device_candidates = [
i for i, d in enumerate(devices) if d["name"] == input_device
]
if len(input_device_candidates) == 0:
LOG.warning(f"Input device {input_device} not found, using default")
input_device = None
else:
input_device = input_device_candidates[0]
if isinstance(output_device, str):
output_device_candidates = [
i for i, d in enumerate(devices) if d["name"] == output_device
]
if len(output_device_candidates) == 0:
LOG.warning(f"Output device {output_device} not found, using default")
output_device = None
else:
output_device = output_device_candidates[0]
if input_device is None or input_device >= len(devices):
input_device = sd.default.device[0]
if output_device is None or output_device >= len(devices):
output_device = sd.default.device[1]
LOG.info(
f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
)

def callback(
indata: np.ndarray,
outdata: np.ndarray,
Expand Down Expand Up @@ -139,6 +170,7 @@ def callback(
).reshape(-1, 1)

with sd.Stream(
device=(input_device, output_device),
channels=1,
callback=callback,
samplerate=svc_model.target_sample,
Expand Down

0 comments on commit 7e5f366

Please sign in to comment.