From 4aad701badc1eae5195e874dec40f9ed8dd40ee6 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Fri, 14 Apr 2023 20:51:46 +0900 Subject: [PATCH] fix: fix torch.load and save to use file objects and weights_only and remove unidecode (#327) --- poetry.lock | 28 +------------------ pyproject.toml | 1 - src/so_vits_svc_fork/cluster/__init__.py | 3 +- src/so_vits_svc_fork/cluster/train_cluster.py | 10 ++++--- src/so_vits_svc_fork/dataset.py | 3 +- .../preprocessing/preprocess_flist_config.py | 6 +--- .../preprocessing/preprocess_hubert_f0.py | 3 +- .../preprocessing/preprocess_resample.py | 8 ------ src/so_vits_svc_fork/utils.py | 22 ++++++++------- 9 files changed, 26 insertions(+), 58 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5bc08f6b..5e927545 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4565,7 +4565,6 @@ files = [ {file = "soundfile-0.12.1-py2.py3-none-any.whl", hash = "sha256:828a79c2e75abab5359f780c81dccd4953c45a2c4cd4f05ba3e233ddf984b882"}, {file = "soundfile-0.12.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d922be1563ce17a69582a352a86f28ed8c9f6a8bc951df63476ffc310c064bfa"}, {file = "soundfile-0.12.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bceaab5c4febb11ea0554566784bcf4bc2e3977b53946dda2b12804b4fe524a8"}, - {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:2dc3685bed7187c072a46ab4ffddd38cef7de9ae5eb05c03df2ad569cf4dacbc"}, {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:074247b771a181859d2bc1f98b5ebf6d5153d2c397b86ee9e29ba602a8dfe2a6"}, {file = "soundfile-0.12.1-py2.py3-none-win32.whl", hash = "sha256:59dfd88c79b48f441bbf6994142a19ab1de3b9bb7c12863402c2bc621e49091a"}, {file = "soundfile-0.12.1-py2.py3-none-win_amd64.whl", hash = "sha256:0d86924c00b62552b650ddd28af426e3ff2d4dc2e9047dae5b3d8452e0a49a77"}, @@ -5023,22 +5022,18 @@ files = [ {file = "torchaudio-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5d21ebbb55e7040d418d5062b0e882f9660d68b477b38fd436fa6c92ccbb52a"}, {file = "torchaudio-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6dbcd93b29d71a2f500f36a34ea5e467f510f773da85322098e6bdd8c9dc9948"}, {file = "torchaudio-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5fdaba10ff06d098d603d9eb8d2ff541c3f3fe28ba178a78787190cec0d5187f"}, - {file = "torchaudio-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:6419199c773c5045c594ff950d5e5dbbfa6c830892ec09721d4ed8704b702bfd"}, {file = "torchaudio-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:a5c81e480e5dcdcba065af1e3e31678ac29518991f00260094d37a39e63d76e5"}, {file = "torchaudio-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e2a047675493c0aa258fec621ef40e8b01abe3d8dbc872152e4b5998418aa3c5"}, {file = "torchaudio-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:91a28e587f708a03320eddbcc4a7dd1ad7150b3d4846b6c1557d85cc89a8d06c"}, {file = "torchaudio-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ba7740d98f601218ff667598ab3d9dab5f326878374fcb52d656f4ff033b9e96"}, - {file = "torchaudio-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:f401b192921c8b77cc5e478ede589b256dba463f1cee91172ecb376fea45a288"}, {file = "torchaudio-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:0ef6754cf75ca5fd5117cb6243a6cf33552d67e9af0075aa6954b2c34bbf1036"}, {file = "torchaudio-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:022ca1baa4bb819b78343bd47b57ff6dc6f9fc19fa4ef269946aadf7e62db3c0"}, {file = "torchaudio-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a153ad5cdb62de8ec9fd1360a0d080bbaf39d578ae04e788db211571e675b7e0"}, {file = "torchaudio-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:aa7897774ab4156d0b72f7078b823ebc1371ee24c50df965447782889552367a"}, - {file = "torchaudio-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:48d133593cddfe0424a350b566d54065bf6fe7469654de7add2f11b3ef03c5d9"}, {file = "torchaudio-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ac65eb067feee435debba81adfe8337fa007a06de6508c0d80261c5562b6d098"}, {file = "torchaudio-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e3c6c8f9ea9f0e2df7a0b9375b0dcf955906e38fc12fab542b72a861564af8e7"}, {file = "torchaudio-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d0cf0779a334ec1861e9fa28bceb66a633c42e8f6b3322e2e37ff9f20d0ae81"}, {file = "torchaudio-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ab7acd2b5d351a2c65e4d935bb90b9256382bed93df57ee177bdbbe31c3cc984"}, - {file = "torchaudio-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:77b953fd7278773269a9477315b8998ae7e5011cc4b2907e0df18162327482f1"}, {file = "torchaudio-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:c01bcea9d4c4a6616452e6cbd44d55913d8e6dee58191b925f35d46a2bf6e71b"}, ] @@ -5150,15 +5145,6 @@ category = "main" optional = false python-versions = "*" files = [ - {file = "triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38806ee9663f4b0f7cd64790e96c579374089e58f49aac4a6608121aa55e2505"}, - {file = "triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:226941c7b8595219ddef59a1fdb821e8c744289a132415ddd584facedeb475b1"}, - {file = "triton-2.0.0-1-cp36-cp36m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4c9fc8c89874bc48eb7e7b2107a9b8d2c0bf139778637be5bfccb09191685cfd"}, - {file = "triton-2.0.0-1-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d2684b6a60b9f174f447f36f933e9a45f31db96cb723723ecd2dcfd1c57b778b"}, - {file = "triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9d4978298b74fcf59a75fe71e535c092b023088933b2f1df933ec32615e4beef"}, - {file = "triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:74f118c12b437fb2ca25e1a04759173b517582fcf4c7be11913316c764213656"}, - {file = "triton-2.0.0-1-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9618815a8da1d9157514f08f855d9e9ff92e329cd81c0305003eb9ec25cc5add"}, - {file = "triton-2.0.0-1-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1aca3303629cd3136375b82cb9921727f804e47ebee27b2677fef23005c3851a"}, - {file = "triton-2.0.0-1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e3e13aa8b527c9b642e3a9defcc0fbd8ffbe1c80d8ac8c15a01692478dc64d8a"}, {file = "triton-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f05a7e64e4ca0565535e3d5d3405d7e49f9d308505bb7773d21fb26a4c008c2"}, {file = "triton-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb4b99ca3c6844066e516658541d876c28a5f6e3a852286bbc97ad57134827fd"}, {file = "triton-2.0.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47b4d70dc92fb40af553b4460492c31dc7d3a114a979ffb7a5cdedb7eb546c08"}, @@ -5294,18 +5280,6 @@ files = [ {file = "ujson-5.7.0.tar.gz", hash = "sha256:e788e5d5dcae8f6118ac9b45d0b891a0d55f7ac480eddcb7f07263f2bcf37b23"}, ] -[[package]] -name = "unidecode" -version = "1.3.6" -description = "ASCII transliterations of Unicode text" -category = "main" -optional = false -python-versions = ">=3.5" -files = [ - {file = "Unidecode-1.3.6-py3-none-any.whl", hash = "sha256:547d7c479e4f377b430dd91ac1275d593308dce0fc464fb2ab7d41f82ec653be"}, - {file = "Unidecode-1.3.6.tar.gz", hash = "sha256:fed09cf0be8cf415b391642c2a5addfc72194407caee4f98719e40ec2a72b830"}, -] - [[package]] name = "urllib3" version = "1.26.15" @@ -5699,4 +5673,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "5a458cd52055f97fa6ae6cfa5bb324cb3c9d40f7d406299fb0fe5851363de192" +content-hash = "5b33ef9ebc86cbbfbc5a0c514c774af5d719196328c6ce461b18d948a9abad21" diff --git a/pyproject.toml b/pyproject.toml index cc2dac7f..d053111d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ cm-time = ">=0.1.2" pysimplegui = ">=4.6" pebble = ">=5.0" torchcrepe = ">=0.0.17" -unidecode = "^1.3.6" lightning = "^2.0.1" fastapi = "==0.88" diff --git a/src/so_vits_svc_fork/cluster/__init__.py b/src/so_vits_svc_fork/cluster/__init__.py index a0866c62..bbe9def2 100644 --- a/src/so_vits_svc_fork/cluster/__init__.py +++ b/src/so_vits_svc_fork/cluster/__init__.py @@ -8,7 +8,8 @@ def get_cluster_model(ckpt_path: Path | str): - checkpoint = torch.load(ckpt_path) + with Path(ckpt_path).open("rb") as f: + checkpoint = torch.load(f, map_location="cpu", weights_only=True) kmeans_dict = {} for spk, ckpt in checkpoint.items(): km = KMeans(ckpt["n_features_in_"]) diff --git a/src/so_vits_svc_fork/cluster/train_cluster.py b/src/so_vits_svc_fork/cluster/train_cluster.py index 44a0c1a3..b48f87da 100644 --- a/src/so_vits_svc_fork/cluster/train_cluster.py +++ b/src/so_vits_svc_fork/cluster/train_cluster.py @@ -24,9 +24,10 @@ def train_cluster( LOG.info(f"Loading features from {input_dir}") features = [] for path in input_dir.rglob("*.data.pt"): - features.append( - torch.load(path, weights_only=True)["content"].squeeze(0).numpy().T - ) + with path.open("rb") as f: + features.append( + torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T + ) if not features: raise ValueError(f"No features found in {input_dir}") features = np.concatenate(features, axis=0).astype(np.float32) @@ -86,4 +87,5 @@ def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]: assert parallel_result is not None checkpoint = dict(parallel_result) output_path.parent.mkdir(exist_ok=True, parents=True) - torch.save(checkpoint, output_path) + with output_path.open("wb") as f: + torch.save(checkpoint, f) diff --git a/src/so_vits_svc_fork/dataset.py b/src/so_vits_svc_fork/dataset.py index 6553394f..7aed7482 100644 --- a/src/so_vits_svc_fork/dataset.py +++ b/src/so_vits_svc_fork/dataset.py @@ -28,7 +28,8 @@ def __init__(self, hps: HParams, is_validation: bool = False): self.max_spec_len = 800 def __getitem__(self, index: int) -> dict[str, torch.Tensor]: - data = torch.load(self.datapaths[index], weights_only=True, map_location="cpu") + with Path(self.datapaths[index]).open("rb") as f: + data = torch.load(f, weights_only=True, map_location="cpu") # cut long data randomly spec_len = data["mel_spec"].shape[1] diff --git a/src/so_vits_svc_fork/preprocessing/preprocess_flist_config.py b/src/so_vits_svc_fork/preprocessing/preprocess_flist_config.py index a6642bee..e6654709 100644 --- a/src/so_vits_svc_fork/preprocessing/preprocess_flist_config.py +++ b/src/so_vits_svc_fork/preprocessing/preprocess_flist_config.py @@ -36,11 +36,7 @@ def preprocess_config( spk_dict[speaker] = spk_id spk_id += 1 paths = [] - for path in tqdm(list((input_dir / speaker).glob("**/*.wav"))): - if not path.name.isascii(): - LOG.warning( - f"file name {path} contains non-ascii characters. torch.save() and torch.load() may not work." - ) + for path in tqdm(list((input_dir / speaker).rglob("*.wav"))): if get_duration(filename=path) < 0.3: LOG.warning(f"skip {path} because it is too short.") continue diff --git a/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py b/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py index 2315f0e2..4951922f 100644 --- a/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py +++ b/src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py @@ -98,7 +98,8 @@ def _process_one( "spk": spk, } data = {k: v.cpu() for k, v in data.items()} - torch.save(data, data_path) + with data_path.open("wb") as f: + torch.save(data, f) def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs): diff --git a/src/so_vits_svc_fork/preprocessing/preprocess_resample.py b/src/so_vits_svc_fork/preprocessing/preprocess_resample.py index a6bb77cd..348c4e9a 100644 --- a/src/so_vits_svc_fork/preprocessing/preprocess_resample.py +++ b/src/so_vits_svc_fork/preprocessing/preprocess_resample.py @@ -9,7 +9,6 @@ import soundfile from joblib import Parallel, delayed from tqdm_joblib import tqdm_joblib -from unidecode import unidecode from .preprocess_utils import check_hubert_min_duration @@ -123,13 +122,6 @@ def preprocess_resample( continue speaker_name = in_path_relative.parts[0] file_name = in_path_relative.with_suffix(".wav").name - new_filename = unidecode(file_name) - if new_filename != file_name: - LOG.warning( - f"Filename {file_name} contains non-ASCII characters. " - f"Replaced with {new_filename}." - ) - file_name = new_filename out_path = output_dir / speaker_name / file_name out_path = _get_unique_filename(out_path, out_paths) out_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index b246ad64..441bec41 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -221,7 +221,8 @@ def load_checkpoint( ) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]: if not Path(checkpoint_path).is_file(): raise FileNotFoundError(f"File {checkpoint_path} not found") - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + with Path(checkpoint_path).open("rb") as f: + checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True) iteration = checkpoint_dict["iteration"] learning_rate = checkpoint_dict["learning_rate"] @@ -260,15 +261,16 @@ def save_checkpoint( state_dict = model.module.state_dict() else: state_dict = model.state_dict() - torch.save( - { - "model": state_dict, - "iteration": iteration, - "optimizer": optimizer.state_dict(), - "learning_rate": learning_rate, - }, - checkpoint_path, - ) + with Path(checkpoint_path).open("wb") as f: + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + f, + ) def clean_checkpoints(