Skip to content

Commit

Permalink
fix(data_utils): allow float32 audio to be processed properly (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueAmulet authored Mar 26, 2023
1 parent 5cd5e00 commit 13943b6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 16 deletions.
8 changes: 3 additions & 5 deletions src/so_vits_svc_fork/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
import torch
import torch.utils.data
import torchaudio

from . import utils
from .modules.mel_processing import spectrogram_torch
from .utils import load_filepaths_and_text, load_wav_to_torch
from .utils import load_filepaths_and_text

# import h5py

Expand All @@ -24,7 +25,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):

def __init__(self, audiopaths, hparams):
self.audiopaths = load_filepaths_and_text(audiopaths)
self.max_wav_value = hparams.data.max_wav_value
self.sampling_rate = hparams.data.sampling_rate
self.filter_length = hparams.data.filter_length
self.hop_length = hparams.data.hop_length
Expand All @@ -39,15 +39,13 @@ def __init__(self, audiopaths, hparams):

def get_audio(self, filename):
filename = filename.replace("\\", "/")
audio, sampling_rate = load_wav_to_torch(filename)
audio_norm, sampling_rate = torchaudio.load(filename)
if sampling_rate != self.sampling_rate:
raise ValueError(
"{} SR doesn't match target {} SR".format(
sampling_rate, self.sampling_rate
)
)
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
spec_filename = filename.replace(".wav", ".spec.pt")
if os.path.exists(spec_filename):
spec = torch.load(spec_filename)
Expand Down
13 changes: 2 additions & 11 deletions src/so_vits_svc_fork/preprocess_flist_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,17 @@
import json
import os
import re
import wave
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from random import shuffle

from librosa import get_duration
from tqdm import tqdm

LOG = getLogger(__name__)


def _get_wav_duration(filepath: Path):
with open(filepath, "rb") as f:
with wave.open(f) as wav_file:
n_frames = wav_file.getnframes()
framerate = wav_file.getframerate()
duration = n_frames / float(framerate)
return duration


def preprocess_config(
input_dir: Path | str,
train_list_path: Path | str,
Expand All @@ -48,7 +39,7 @@ def preprocess_config(
pattern = re.compile(r"^[\.a-zA-Z0-9_\/]+$")
if not pattern.match(path.name):
LOG.warning(f"file name {path} contains non-alphanumeric characters.")
if _get_wav_duration(path) < 0.3:
if get_duration(filename=path) < 0.3:
LOG.warning(f"skip {path} because it is too short.")
continue
paths.append(path)
Expand Down

0 comments on commit 13943b6

Please sign in to comment.