Commit 4aad701

fix: fix torch.load and save to use file objects and weights_only and remove unidecode (#327)

34j authored Apr 14, 2023
1 parent 296da44 commit 4aad701
Showing 9 changed files with 26 additions and 58 deletions.
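
A note on the pattern this commit applies everywhere: open the checkpoint path explicitly as a binary file object, hand that object to torch.save/torch.load, and pass weights_only=True on the load side. The warnings removed below noted that torch.save()/torch.load() could fail outright on non-ASCII filenames; opening the file ourselves sidesteps torch's own path handling, and weights_only=True restricts unpickling to tensors and plain containers so a tampered checkpoint cannot execute code at load time. A minimal sketch of the round trip (hypothetical path; weights_only requires torch >= 1.13):

import tempfile
from pathlib import Path

import torch

ckpt_path = Path(tempfile.gettempdir()) / "demo.pt"  # hypothetical path

# Save through a binary file object we opened ourselves, so the (possibly
# non-ASCII) filename never passes through torch's internal path handling.
with ckpt_path.open("wb") as f:
    torch.save({"iteration": 0, "weight": torch.zeros(3, 3)}, f)

# weights_only=True restricts unpickling to tensors and plain containers.
with ckpt_path.open("rb") as f:
    state = torch.load(f, map_location="cpu", weights_only=True)
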
28 changes: 1 addition & 27 deletions poetry.lock

1 change: 0 additions & 1 deletion pyproject.toml
@@ -62,7 +62,6 @@ cm-time = ">=0.1.2"
 pysimplegui = ">=4.6"
 pebble = ">=5.0"
 torchcrepe = ">=0.0.17"
-unidecode = "^1.3.6"
 lightning = "^2.0.1"
 fastapi = "==0.88"

3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/cluster/__init__.py
@@ -8,7 +8,8 @@


 def get_cluster_model(ckpt_path: Path | str):
-    checkpoint = torch.load(ckpt_path)
+    with Path(ckpt_path).open("rb") as f:
+        checkpoint = torch.load(f, map_location="cpu", weights_only=True)
     kmeans_dict = {}
     for spk, ckpt in checkpoint.items():
         km = KMeans(ckpt["n_features_in_"])
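
Because weights_only=True cannot unpickle an sklearn estimator, the checkpoint stores each speaker's KMeans as a plain attribute dict, and the loop above rebuilds the object by hand. A hedged sketch of that reconstruction; every key except n_features_in_ is an assumption here, and cluster_centers_ is assumed to be stored as a tensor:

from sklearn.cluster import KMeans

def restore_kmeans(ckpt: dict) -> KMeans:
    km = KMeans(ckpt["n_features_in_"])
    # Bypass fitting by writing the learned attributes directly.
    km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
    km.__dict__["_n_threads"] = ckpt.get("_n_threads", 1)
    km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"].numpy()
    return km
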
10 changes: 6 additions & 4 deletions src/so_vits_svc_fork/cluster/train_cluster.py
@@ -24,9 +24,10 @@ def train_cluster(
     LOG.info(f"Loading features from {input_dir}")
     features = []
     for path in input_dir.rglob("*.data.pt"):
-        features.append(
-            torch.load(path, weights_only=True)["content"].squeeze(0).numpy().T
-        )
+        with path.open("rb") as f:
+            features.append(
+                torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
+            )
     if not features:
         raise ValueError(f"No features found in {input_dir}")
     features = np.concatenate(features, axis=0).astype(np.float32)

@@ -86,4 +87,5 @@ def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
     assert parallel_result is not None
     checkpoint = dict(parallel_result)
     output_path.parent.mkdir(exist_ok=True, parents=True)
-    torch.save(checkpoint, output_path)
+    with output_path.open("wb") as f:
+        torch.save(checkpoint, f)
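
The save side stays loadable under weights_only=True only because the aggregated checkpoint holds nothing but tensors and primitives. A self-contained sketch of that constraint (the exact error type is assumed; recent torch raises pickle.UnpicklingError):

import io

import torch

class NotATensor:
    """Stand-in for an arbitrary Python object, e.g. a fitted estimator."""

buf = io.BytesIO()
torch.save({"obj": NotATensor()}, buf)  # saving arbitrary objects still works
buf.seek(0)
try:
    torch.load(buf, weights_only=True)  # but safe loading refuses them
except Exception as e:
    print("refused:", type(e).__name__)

buf = io.BytesIO()
torch.save({"content": torch.randn(1, 256, 100)}, buf)  # tensors-only dict
buf.seek(0)
data = torch.load(buf, weights_only=True)  # round-trips cleanly
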
3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/dataset.py
@@ -28,7 +28,8 @@ def __init__(self, hps: HParams, is_validation: bool = False):
         self.max_spec_len = 800

     def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
-        data = torch.load(self.datapaths[index], weights_only=True, map_location="cpu")
+        with Path(self.datapaths[index]).open("rb") as f:
+            data = torch.load(f, weights_only=True, map_location="cpu")

         # cut long data randomly
         spec_len = data["mel_spec"].shape[1]
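
Right after this load, __getitem__ "cut[s] long data randomly" down to max_spec_len (800) frames. A hedged sketch of that step, simplified to the mel spectrogram alone; the real code presumably crops the other frame-aligned keys as well:

import torch

def random_crop(data: dict[str, torch.Tensor], max_spec_len: int = 800) -> dict[str, torch.Tensor]:
    # Pick a random window of at most max_spec_len frames.
    spec_len = data["mel_spec"].shape[1]
    if spec_len > max_spec_len:
        start = int(torch.randint(0, spec_len - max_spec_len + 1, (1,)))
        data["mel_spec"] = data["mel_spec"][:, start : start + max_spec_len]
    return data
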
@@ -36,11 +36,7 @@ def preprocess_config(
         spk_dict[speaker] = spk_id
         spk_id += 1
         paths = []
-        for path in tqdm(list((input_dir / speaker).glob("**/*.wav"))):
-            if not path.name.isascii():
-                LOG.warning(
-                    f"file name {path} contains non-ascii characters. torch.save() and torch.load() may not work."
-                )
+        for path in tqdm(list((input_dir / speaker).rglob("*.wav"))):
             if get_duration(filename=path) < 0.3:
                 LOG.warning(f"skip {path} because it is too short.")
                 continue
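
rglob("*.wav") is the tidier spelling of glob("**/*.wav"), and the non-ASCII warning could go because all torch I/O now happens through file objects. A small sketch of the resulting filter, assuming the get_duration above is librosa's:

from pathlib import Path

from librosa import get_duration  # assumed source of the get_duration used above

def usable_wavs(speaker_dir: Path, min_seconds: float = 0.3) -> list[Path]:
    # rglob("*.wav") walks the tree recursively, same as glob("**/*.wav").
    return [
        p for p in sorted(speaker_dir.rglob("*.wav"))
        if get_duration(filename=p) >= min_seconds
    ]
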
3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py
@@ -98,7 +98,8 @@ def _process_one(
         "spk": spk,
     }
     data = {k: v.cpu() for k, v in data.items()}
-    torch.save(data, data_path)
+    with data_path.open("wb") as f:
+        torch.save(data, f)


 def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs):
8 changes: 0 additions & 8 deletions src/so_vits_svc_fork/preprocessing/preprocess_resample.py
@@ -9,7 +9,6 @@
 import soundfile
 from joblib import Parallel, delayed
 from tqdm_joblib import tqdm_joblib
-from unidecode import unidecode

 from .preprocess_utils import check_hubert_min_duration

@@ -123,13 +122,6 @@ def preprocess_resample(
             continue
         speaker_name = in_path_relative.parts[0]
         file_name = in_path_relative.with_suffix(".wav").name
-        new_filename = unidecode(file_name)
-        if new_filename != file_name:
-            LOG.warning(
-                f"Filename {file_name} contains non-ASCII characters. "
-                f"Replaced with {new_filename}."
-            )
-        file_name = new_filename
         out_path = output_dir / speaker_name / file_name
         out_path = _get_unique_filename(out_path, out_paths)
         out_path.parent.mkdir(parents=True, exist_ok=True)
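
With unidecode gone, two distinct non-ASCII names can no longer be flattened into the same ASCII string, so the only collisions left for _get_unique_filename are genuinely duplicate names across subfolders. Its body isn't shown in this diff; a hypothetical sketch of such a helper:

from pathlib import Path

def get_unique_filename(out_path: Path, taken: set[Path]) -> Path:
    # Hypothetical reimplementation: suffix a counter until the name is free.
    candidate, i = out_path, 1
    while candidate in taken:
        candidate = out_path.with_name(f"{out_path.stem}_{i}{out_path.suffix}")
        i += 1
    taken.add(candidate)
    return candidate
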
22 changes: 12 additions & 10 deletions src/so_vits_svc_fork/utils.py
@@ -221,7 +221,8 @@ def load_checkpoint(
 ) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]:
     if not Path(checkpoint_path).is_file():
         raise FileNotFoundError(f"File {checkpoint_path} not found")
-    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+    with Path(checkpoint_path).open("rb") as f:
+        checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]

@@ -260,15 +261,16 @@ def save_checkpoint(
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
-    torch.save(
-        {
-            "model": state_dict,
-            "iteration": iteration,
-            "optimizer": optimizer.state_dict(),
-            "learning_rate": learning_rate,
-        },
-        checkpoint_path,
-    )
+    with Path(checkpoint_path).open("wb") as f:
+        torch.save(
+            {
+                "model": state_dict,
+                "iteration": iteration,
+                "optimizer": optimizer.state_dict(),
+                "learning_rate": learning_rate,
+            },
+            f,
+        )


 def clean_checkpoints(
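
Taken together, save_checkpoint and load_checkpoint now round-trip entirely through file objects with safe loading. A minimal usage sketch with toy objects (real callers pass the actual networks and optimizers):

import tempfile
from pathlib import Path

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
path = Path(tempfile.gettempdir()) / "G_1.pth"  # hypothetical location

with path.open("wb") as f:
    torch.save(
        {
            "model": model.state_dict(),
            "iteration": 1,
            "optimizer": optimizer.state_dict(),
            "learning_rate": 1e-4,
        },
        f,
    )

with path.open("rb") as f:
    ckpt = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
assert ckpt["iteration"] == 1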
