From 15f1e7ffca80cb551316affae546ea72e8cccb34 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Wed, 5 Apr 2023 11:13:17 +0900 Subject: [PATCH] fix: fix issues when loading legacy checkpoint and fix pre-hubert n_jobs (#236) --- .../config_templates/quickvc.json | 2 +- .../so-vits-svc-4.0v1-legacy.json | 2 +- .../config_templates/so-vits-svc-4.0v1.json | 2 +- .../preprocessing/preprocess_hubert_f0.py | 22 ++++--- src/so_vits_svc_fork/utils.py | 63 +++++++++++-------- 5 files changed, 54 insertions(+), 37 deletions(-) diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json b/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json index dbef0a69..88012218 100644 --- a/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json +++ b/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json @@ -7,7 +7,7 @@ "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-9, - "batch_size": 18, + "batch_size": 12, "fp16_run": false, "lr_decay": 0.999875, "segment_size": 10240, diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json index 45852762..99d8f94c 100644 --- a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json +++ b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json @@ -7,7 +7,7 @@ "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-9, - "batch_size": 6, + "batch_size": 18, "fp16_run": false, "lr_decay": 0.999875, "segment_size": 10240, diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json index d4c8d46f..789015dd 100644 --- a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json +++ b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json @@ -7,7 +7,7 @@ "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-9, - "batch_size": 18, + "batch_size": 12, "fp16_run": false, "lr_decay": 0.999875, "segment_size": 10240, diff --git a/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py b/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py index 2543457b..2dfb52d8 100644 --- a/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py +++ b/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py @@ -10,7 +10,7 @@ import torch import torchaudio from fairseq.models.hubert import HubertModel -from joblib import Parallel, delayed +from joblib import Parallel, cpu_count, delayed from tqdm import tqdm import so_vits_svc_fork.f0 @@ -22,8 +22,8 @@ from .preprocess_utils import check_hubert_min_duration LOG = getLogger(__name__) -HUBERT_MEMORY = 1600 -HUBERT_MEMORY_CREPE = 2600 +HUBERT_MEMORY = 2900 +HUBERT_MEMORY_CREPE = 3900 def _process_one( @@ -124,11 +124,17 @@ def preprocess_hubert_f0( utils.ensure_pretrained_model(".", "contentvec") hps = utils.get_hparams(config_path) if n_jobs is None: - memory = get_total_gpu_memory("free") - n_jobs = ( - memory // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY) - if memory is not None - else 1 + # add cpu_count() to avoid SIGKILL + memory = get_total_gpu_memory("total") + n_jobs = min( + max( + memory + // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY) + if memory is not None + else 1, + 1, + ), + cpu_count(), ) LOG.info(f"n_jobs automatically set to {n_jobs}, memory: {memory} MiB") diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index b0bd820d..98c9cf3c 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -163,6 +163,32 @@ def get_content( return c +def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None: + for k, v in from_.items(): + if k not in to_: + warnings.warn(f"Key {k} not found in model state dict") + elif hasattr(v, "shape"): + if not hasattr(to_[k], "shape"): + raise ValueError(f"Key {k} is not a tensor") + if to_[k].shape == v.shape: + to_[k] = v + else: + warnings.warn( + f"Shape mismatch for key {k}, {to_[k].shape} != {v.shape}" + ) + elif isinstance(v, dict): + assert isinstance(to_[k], dict) + _substitute_if_same_shape(to_[k], v) + else: + to_[k] = v + + +def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None: + model_state_dict = model.state_dict() + _substitute_if_same_shape(model_state_dict, state_dict) + model.load_state_dict(model_state_dict) + + def load_checkpoint( checkpoint_path: Path | str, model: torch.nn.Module, @@ -174,37 +200,22 @@ def load_checkpoint( checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") iteration = checkpoint_dict["iteration"] learning_rate = checkpoint_dict["learning_rate"] + + # safe load module + if hasattr(model, "module"): + safe_load(model.module, checkpoint_dict["model"]) + else: + safe_load(model, checkpoint_dict["model"]) + # safe load optim if ( optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None ): - try: - optimizer.load_state_dict(checkpoint_dict["optimizer"]) - except Exception as e: - LOG.exception(e) - LOG.warning("Failed to load optimizer state") - saved_state_dict = checkpoint_dict["model"] - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - new_state_dict = {} - for k, v in state_dict.items(): - try: - new_state_dict[k] = saved_state_dict[k] - assert saved_state_dict[k].shape == v.shape, ( - saved_state_dict[k].shape, - v.shape, - ) - except Exception as e: - LOG.exception(e) - LOG.error("%s is not in the checkpoint" % k) - new_state_dict[k] = v - if hasattr(model, "module"): - model.module.load_state_dict(new_state_dict) - else: - model.load_state_dict(new_state_dict) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + safe_load(optimizer, checkpoint_dict["optimizer"]) + LOG.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})") return model, optimizer, learning_rate, iteration