diff --git a/src/so_vits_svc_fork/__main__.py b/src/so_vits_svc_fork/__main__.py index 4e1505a1..d3bda208 100644 --- a/src/so_vits_svc_fork/__main__.py +++ b/src/so_vits_svc_fork/__main__.py @@ -287,6 +287,8 @@ def infer( ) @click.option("-s", "--speaker", type=str, default=None, help="speaker name") @click.option("-v", "--version", type=int, default=2, help="version") +@click.option("-i", "--input-device", type=int, default=None, help="input device") +@click.option("-o", "--output-device", type=int, default=None, help="output device") def vc( # paths model_path: Path, @@ -306,6 +308,8 @@ def vc( crossfade_seconds: float, block_seconds: float, version: int, + input_device: int | str | None, + output_device: int | str | None, device: Literal["cpu", "cuda"], ) -> None: """Realtime inference from microphone""" @@ -343,6 +347,8 @@ def vc( db_thresh=db_thresh, pad_seconds=pad_seconds, version=version, + input_device=input_device, + output_device=output_device, device=device, ) diff --git a/src/so_vits_svc_fork/gui.py b/src/so_vits_svc_fork/gui.py index 16a6e4c9..3586a3a7 100644 --- a/src/so_vits_svc_fork/gui.py +++ b/src/so_vits_svc_fork/gui.py @@ -25,6 +25,13 @@ def play_audio(path: Path | str): def main(): sg.theme("Dark") model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth"))) + + devices = sd.query_devices() + input_devices = [d["name"] for d in devices if d["max_input_channels"] > 0] + output_devices = [d["name"] for d in devices if d["max_output_channels"] > 0] + devices[sd.default.device[0]]["name"] + devices[sd.default.device[1]]["name"] + layout = [ [ sg.Frame( @@ -221,6 +228,24 @@ def main(): key="realtime_algorithm", ), ], + [ + sg.Text("Input device"), + sg.Combo( + key="input_device", + values=input_devices, + size=(20, 1), + default_value=input_devices[0], + ), + ], + [ + sg.Text("Output device"), + sg.Combo( + key="output_device", + values=output_devices, + size=(20, 1), + default_value=output_devices[0], + ), + ], ], ) ], @@ -332,6 +357,8 @@ def update_combo() -> None: version=int(values["realtime_algorithm"][0]), device="cuda" if values["use_gpu"] else "cpu", block_seconds=values["block_seconds"], + input_device=values["input_device"], + output_device=values["output_device"], ), ) elif event == "stop_vc": diff --git a/src/so_vits_svc_fork/inference_main.py b/src/so_vits_svc_fork/inference_main.py index c5a10bd3..40a6f0b0 100644 --- a/src/so_vits_svc_fork/inference_main.py +++ b/src/so_vits_svc_fork/inference_main.py @@ -86,6 +86,8 @@ def realtime( crossfade_seconds: float = 0.05, block_seconds: float = 0.5, version: int = 2, + input_device: int | str | None = None, + output_device: int | str | None = None, device: Literal["cpu", "cuda"] = "cuda" if torch.cuda.is_available() else "cpu", ): import sounddevice as sd @@ -111,6 +113,35 @@ def realtime( svc_model=svc_model, ) + # LOG all device info + devices = sd.query_devices() + LOG.info(f"Device: {devices}") + if isinstance(input_device, str): + input_device_candidates = [ + i for i, d in enumerate(devices) if d["name"] == input_device + ] + if len(input_device_candidates) == 0: + LOG.warning(f"Input device {input_device} not found, using default") + input_device = None + else: + input_device = input_device_candidates[0] + if isinstance(output_device, str): + output_device_candidates = [ + i for i, d in enumerate(devices) if d["name"] == output_device + ] + if len(output_device_candidates) == 0: + LOG.warning(f"Output device {output_device} not found, using default") + output_device = None + else: + output_device = output_device_candidates[0] + if input_device is None or input_device >= len(devices): + input_device = sd.default.device[0] + if output_device is None or output_device >= len(devices): + output_device = sd.default.device[1] + LOG.info( + f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}" + ) + def callback( indata: np.ndarray, outdata: np.ndarray, @@ -139,6 +170,7 @@ def callback( ).reshape(-1, 1) with sd.Stream( + device=(input_device, output_device), channels=1, callback=callback, samplerate=svc_model.target_sample,