fix: fix torch.load and save to use file objects to allow non-ASCII characters and use weights_only and remove unidecode #327

Merged · 2 commits · Apr 14, 2023
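Every call site in this PR follows the same two-line pattern: open the file with pathlib and hand torch the stream, so PyTorch never has to open the path itself (its own path handling has historically failed on some non-ASCII paths, notably on Windows), and pass `weights_only=True` when loading so the unpickler only accepts tensor-safe types. A minimal sketch of the pattern, with a made-up checkpoint path and payload:

```python
from pathlib import Path

import torch

ckpt_path = Path("checkpoints/モデル.pt")  # hypothetical non-ASCII path
ckpt_path.parent.mkdir(parents=True, exist_ok=True)

state = {"weight": torch.zeros(3)}

# Save: Python opens the file; torch only writes to the stream.
with ckpt_path.open("wb") as f:
    torch.save(state, f)

# Load: same pattern, plus weights_only=True (PyTorch >= 1.13) so the
# unpickler only accepts tensors and plain containers.
with ckpt_path.open("rb") as f:
    restored = torch.load(f, map_location="cpu", weights_only=True)
```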
28 changes: 1 addition & 27 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
```diff
@@ -62,7 +62,6 @@ cm-time = ">=0.1.2"
 pysimplegui = ">=4.6"
 pebble = ">=5.0"
 torchcrepe = ">=0.0.17"
-unidecode = "^1.3.6"
 lightning = "^2.0.1"
 fastapi = "==0.88"
```
3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/cluster/__init__.py
```diff
@@ -8,7 +8,8 @@
 
 
 def get_cluster_model(ckpt_path: Path | str):
-    checkpoint = torch.load(ckpt_path)
+    with Path(ckpt_path).open("rb") as f:
+        checkpoint = torch.load(f, map_location="cpu", weights_only=True)
     kmeans_dict = {}
     for spk, ckpt in checkpoint.items():
         km = KMeans(ckpt["n_features_in_"])
```
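`weights_only=True` is also a hardening measure: plain `torch.load` will happily execute arbitrary callables smuggled into a checkpoint's pickle stream, whereas the restricted unpickler rejects them. A contrived sketch — the `Evil` class and path are purely illustrative:

```python
from pathlib import Path

import torch


class Evil:
    def __reduce__(self):
        # Unpickling this object calls print() — it could be any callable.
        return (print, ("arbitrary code ran here",))


path = Path("untrusted.pt")
with path.open("wb") as f:
    torch.save({"payload": Evil()}, f)

with path.open("rb") as f:
    try:
        torch.load(f, weights_only=True)  # refuses the __main__.Evil global
    except Exception as e:
        print(f"blocked: {e}")
```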
10 changes: 6 additions & 4 deletions src/so_vits_svc_fork/cluster/train_cluster.py
```diff
@@ -24,9 +24,10 @@ def train_cluster(
     LOG.info(f"Loading features from {input_dir}")
     features = []
     for path in input_dir.rglob("*.data.pt"):
-        features.append(
-            torch.load(path, weights_only=True)["content"].squeeze(0).numpy().T
-        )
+        with path.open("rb") as f:
+            features.append(
+                torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
+            )
     if not features:
         raise ValueError(f"No features found in {input_dir}")
     features = np.concatenate(features, axis=0).astype(np.float32)
@@ -86,4 +87,5 @@ def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
     assert parallel_result is not None
     checkpoint = dict(parallel_result)
     output_path.parent.mkdir(exist_ok=True, parents=True)
-    torch.save(checkpoint, output_path)
+    with output_path.open("wb") as f:
+        torch.save(checkpoint, f)
```
3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/dataset.py
```diff
@@ -28,7 +28,8 @@ def __init__(self, hps: HParams, is_validation: bool = False):
         self.max_spec_len = 800
 
     def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
-        data = torch.load(self.datapaths[index], weights_only=True, map_location="cpu")
+        with Path(self.datapaths[index]).open("rb") as f:
+            data = torch.load(f, weights_only=True, map_location="cpu")
 
         # cut long data randomly
         spec_len = data["mel_spec"].shape[1]
```
6 changes: 1 addition & 5 deletions src/so_vits_svc_fork/preprocessing/preprocess_flist_config.py
```diff
@@ -36,11 +36,7 @@ def preprocess_config(
         spk_dict[speaker] = spk_id
         spk_id += 1
         paths = []
-        for path in tqdm(list((input_dir / speaker).glob("**/*.wav"))):
-            if not path.name.isascii():
-                LOG.warning(
-                    f"file name {path} contains non-ascii characters. torch.save() and torch.load() may not work."
-                )
+        for path in tqdm(list((input_dir / speaker).rglob("*.wav"))):
             if get_duration(filename=path) < 0.3:
                 LOG.warning(f"skip {path} because it is too short.")
                 continue
```
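Two things change in the loop above: the non-ASCII filename warning is dropped (file-object I/O makes it moot), and `glob("**/*.wav")` becomes the equivalent but tidier `rglob("*.wav")` — per the pathlib docs, `rglob(pattern)` is shorthand for `glob("**/" + pattern)`. A quick sanity check with a hypothetical dataset directory:

```python
from pathlib import Path

speaker_dir = Path("dataset_raw/speaker0")  # hypothetical
assert sorted(speaker_dir.glob("**/*.wav")) == sorted(speaker_dir.rglob("*.wav"))
```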
3 changes: 2 additions & 1 deletion src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py
```diff
@@ -98,7 +98,8 @@ def _process_one(
         "spk": spk,
     }
     data = {k: v.cpu() for k, v in data.items()}
-    torch.save(data, data_path)
+    with data_path.open("wb") as f:
+        torch.save(data, f)
 
 
 def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs):
```
8 changes: 0 additions & 8 deletions src/so_vits_svc_fork/preprocessing/preprocess_resample.py
```diff
@@ -9,7 +9,6 @@
 import soundfile
 from joblib import Parallel, delayed
 from tqdm_joblib import tqdm_joblib
-from unidecode import unidecode
 
 from .preprocess_utils import check_hubert_min_duration
 
@@ -123,13 +122,6 @@ def preprocess_resample(
             continue
         speaker_name = in_path_relative.parts[0]
         file_name = in_path_relative.with_suffix(".wav").name
-        new_filename = unidecode(file_name)
-        if new_filename != file_name:
-            LOG.warning(
-                f"Filename {file_name} contains non-ASCII characters. "
-                f"Replaced with {new_filename}."
-            )
-        file_name = new_filename
         out_path = output_dir / speaker_name / file_name
         out_path = _get_unique_filename(out_path, out_paths)
         out_path.parent.mkdir(parents=True, exist_ok=True)
```
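With unidecode gone, resampled files keep their original names; nothing is transliterated before the output path is built. A sketch of the resulting path construction — the directory and file names are made up:

```python
from pathlib import Path

output_dir = Path("dataset/44k")               # hypothetical output root
in_path_relative = Path("歌手01/ボーカル.wav")  # hypothetical input, relative to input_dir

speaker_name = in_path_relative.parts[0]
file_name = in_path_relative.with_suffix(".wav").name  # "ボーカル.wav", kept verbatim
out_path = output_dir / speaker_name / file_name
print(out_path)  # dataset/44k/歌手01/ボーカル.wav
```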
22 changes: 12 additions & 10 deletions src/so_vits_svc_fork/utils.py
```diff
@@ -221,7 +221,8 @@ def load_checkpoint(
 ) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]:
     if not Path(checkpoint_path).is_file():
         raise FileNotFoundError(f"File {checkpoint_path} not found")
-    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+    with Path(checkpoint_path).open("rb") as f:
+        checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
 
@@ -260,15 +261,16 @@ def save_checkpoint(
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
-    torch.save(
-        {
-            "model": state_dict,
-            "iteration": iteration,
-            "optimizer": optimizer.state_dict(),
-            "learning_rate": learning_rate,
-        },
-        checkpoint_path,
-    )
+    with Path(checkpoint_path).open("wb") as f:
+        torch.save(
+            {
+                "model": state_dict,
+                "iteration": iteration,
+                "optimizer": optimizer.state_dict(),
+                "learning_rate": learning_rate,
+            },
+            f,
+        )
 
 
 def clean_checkpoints(
```
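Note that `weights_only=True` still round-trips everything `save_checkpoint` writes: optimizer state_dicts are plain containers of tensors, numbers, and booleans, all of which the restricted unpickler accepts. A minimal sketch with a toy model — the model, path, and hyperparameters are made up:

```python
from pathlib import Path

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
ckpt_path = Path("logs/44k/G_0.pth")  # hypothetical
ckpt_path.parent.mkdir(parents=True, exist_ok=True)

# Mirror the layout save_checkpoint writes above...
with ckpt_path.open("wb") as f:
    torch.save(
        {
            "model": model.state_dict(),
            "iteration": 1,
            "optimizer": optimizer.state_dict(),
            "learning_rate": 1e-4,
        },
        f,
    )

# ...and the fields load_checkpoint reads back.
with ckpt_path.open("rb") as f:
    checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
assert checkpoint_dict["iteration"] == 1
```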