
perf: move methods from dataloader to pre-hubert #218

Merged · 14 commits · Apr 4, 2023
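This PR moves feature extraction out of the dataloader and into the preprocessing stage, so the dataset only loads cached tensors from disk. Judging from the keys the new __getitem__ and collator read, each cached *.data.pt file holds a dict roughly like the sketch below; the shapes, the hop length of 512, and the integer speaker id are assumptions, since the preprocessing side is not part of this diff:

# Sketch of one cached *.data.pt entry, inferred from the keys used in
# __getitem__ and TextAudioCollate below. All shapes are assumptions
# for illustration (400 frames, hop length 512).
import torch

example = {
    "content": torch.randn(256, 400),    # content-encoder features, [C, frames]
    "f0": torch.randn(400),              # interpolated F0 curve, [frames]
    "spec": torch.randn(513, 400),       # linear spectrogram, [n_fft // 2 + 1, frames]
    "mel_spec": torch.randn(80, 400),    # mel spectrogram, [n_mels, frames]
    "audio": torch.randn(1, 400 * 512),  # waveform, [1, frames * hop_length]
    "spk": 0,                            # speaker id (assumed to be a plain int)
    "uv": torch.ones(400),               # voiced/unvoiced mask, [frames]
}
torch.save(example, "example.wav.data.pt")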
216 changes: 72 additions & 144 deletions src/so_vits_svc_fork/data_utils.py
@@ -1,157 +1,85 @@
-import os
-import random
+from __future__ import annotations
+
+from pathlib import Path
+from random import Random
+from typing import Sequence
 
-import numpy as np
 import torch
+import torch.nn.functional as F
 import torch.utils.data
-import torchaudio
 
-import so_vits_svc_fork.f0
-
-from . import utils
-from .modules.mel_processing import spectrogram_torch
-from .utils import load_filepaths_and_text
-
-# import h5py
-
-
-# Multi speaker version
+from .hparams import HParams
 
 
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
-    """
-    1) loads audio, speaker_id, text pairs
-    2) normalizes text and converts them to sequences of integers
-    3) computes spectrograms from audio files.
-    """
-
-    def __init__(self, audiopaths, hparams):
-        self.audiopaths = load_filepaths_and_text(audiopaths)
-        self.sampling_rate = hparams.data.sampling_rate
-        self.filter_length = hparams.data.filter_length
-        self.hop_length = hparams.data.hop_length
-        self.win_length = hparams.data.win_length
-        self.sampling_rate = hparams.data.sampling_rate
-        self.use_sr = hparams.train.use_sr
-        self.spec_len = hparams.train.max_speclen
-        self.spk_map = hparams.spk
-
-        random.seed(1234)
-        random.shuffle(self.audiopaths)
-
-    def get_audio(self, filename):
-        filename = filename.replace("\\", "/")
-        audio_norm, sampling_rate = torchaudio.load(filename)
-        if sampling_rate != self.sampling_rate:
-            raise ValueError(
-                "{} SR doesn't match target {} SR".format(
-                    sampling_rate, self.sampling_rate
-                )
-            )
-        spec_filename = filename.replace(".wav", ".spec.pt")
-        if os.path.exists(spec_filename):
-            spec = torch.load(spec_filename)
-        else:
-            spec = spectrogram_torch(
-                audio_norm,
-                self.filter_length,
-                self.sampling_rate,
-                self.hop_length,
-                self.win_length,
-                center=False,
-            )
-            spec = torch.squeeze(spec, 0)
-            torch.save(spec, spec_filename)
-
-        spk = filename.split("/")[-2]
-        spk = torch.LongTensor([self.spk_map[spk]])
-
-        f0 = np.load(filename + ".f0.npy")
-        f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
-        f0 = torch.FloatTensor(f0)
-        uv = torch.FloatTensor(uv)
-
-        c = torch.load(filename + ".soft.pt")
-        c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
-
-        lmin = min(c.size(-1), spec.size(-1))
-        assert abs(c.size(-1) - spec.size(-1)) < 3, (
-            c.size(-1),
-            spec.size(-1),
-            f0.shape,
-            filename,
-        )
-        assert abs(audio_norm.shape[1] - lmin * self.hop_length) < 3 * self.hop_length
-        spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
-        audio_norm = audio_norm[:, : lmin * self.hop_length]
-        # if spec.shape[1] < 30:
-        #     print("skip too short audio:", filename)
-        #     return None
-        if spec.shape[1] > 800:
-            start = random.randint(0, spec.shape[1] - 800)
-            end = start + 790
-            spec, c, f0, uv = (
-                spec[:, start:end],
-                c[:, start:end],
-                f0[start:end],
-                uv[start:end],
-            )
-            audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length]
-
-        return c, f0, spec, audio_norm, spk, uv
-
-    def __getitem__(self, index):
-        return self.get_audio(self.audiopaths[index][0])
-
-    def __len__(self):
-        return len(self.audiopaths)
+    def __init__(self, hps: HParams, is_validation: bool = False):
+        self.datapaths = [
+            Path(x).parent / (Path(x).name + ".data.pt")
+            for x in Path(
+                hps.data.validation_files if is_validation else hps.data.training_files
+            )
+            .read_text()
+            .splitlines()
+        ]
+        self.hps = hps
+        self.random = Random(hps.train.seed)
+        self.random.shuffle(self.datapaths)
+        self.max_spec_len = 800
+
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
+        data = torch.load(self.datapaths[index], weights_only=True, map_location="cpu")
+
+        # cut long data randomly
+        spec_len = data["mel_spec"].shape[1]
+        hop_len = self.hps.data.hop_length
+        if spec_len > self.max_spec_len:
+            start = self.random.randint(0, spec_len - self.max_spec_len)
+            end = start + self.max_spec_len - 10
+            for key in data.keys():
+                if key == "audio":
+                    data[key] = data[key][:, start * hop_len : end * hop_len]
+                elif key == "spk":
+                    continue
+                else:
+                    data[key] = data[key][..., start:end]
+        torch.cuda.empty_cache()
+        return data
+
+    def __len__(self) -> int:
+        return len(self.datapaths)
+
+
+def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
+    max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array]))
+    max_x = array[max_idx]
+    x_padded = [
+        F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0)
+        for x_ in array
+    ]
+    return torch.stack(x_padded)
 
 
 class TextAudioCollate:
-    def __call__(self, batch):
+    def __call__(
+        self, batch: Sequence[dict[str, torch.Tensor]]
+    ) -> tuple[torch.Tensor, ...]:
         batch = [b for b in batch if b is not None]
-
-        input_lengths, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[0].shape[1] for x in batch]), dim=0, descending=True
-        )
-
-        max_c_len = max([x[0].size(1) for x in batch])
-        max_wav_len = max([x[3].size(1) for x in batch])
-
-        lengths = torch.LongTensor(len(batch))
-
-        c_padded = torch.FloatTensor(len(batch), batch[0][0].shape[0], max_c_len)
-        f0_padded = torch.FloatTensor(len(batch), max_c_len)
-        spec_padded = torch.FloatTensor(len(batch), batch[0][2].shape[0], max_c_len)
-        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
-        spkids = torch.LongTensor(len(batch), 1)
-        uv_padded = torch.FloatTensor(len(batch), max_c_len)
-
-        c_padded.zero_()
-        spec_padded.zero_()
-        f0_padded.zero_()
-        wav_padded.zero_()
-        uv_padded.zero_()
-
-        for i in range(len(ids_sorted_decreasing)):
-            row = batch[ids_sorted_decreasing[i]]
-
-            c = row[0]
-            c_padded[i, :, : c.size(1)] = c
-            lengths[i] = c.size(1)
-
-            f0 = row[1]
-            f0_padded[i, : f0.size(0)] = f0
-
-            spec = row[2]
-            spec_padded[i, :, : spec.size(1)] = spec
-
-            wav = row[3]
-            wav_padded[i, :, : wav.size(1)] = wav
-
-            spkids[i, 0] = row[4]
-
-            uv = row[5]
-            uv_padded[i, : uv.size(0)] = uv
-
-        return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded
+        batch = list(sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True))
+        lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
+        results = {}
+        for key in batch[0].keys():
+            if key not in ["spk"]:
+                results[key] = _pad_stack([b[key] for b in batch]).cpu()
+            else:
+                results[key] = torch.tensor([[b[key]] for b in batch]).cpu()
+
+        return (
+            results["content"],
+            results["f0"],
+            results["spec"],
+            results["mel_spec"],
+            results["audio"],
+            results["spk"],
+            lengths,
+            results["uv"],
+        )
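For orientation, a minimal sketch of how the reworked dataset and collator are wired together; the DataLoader call and the hps.train.batch_size field are assumed from typical usage and are not shown in this diff:

# Usage sketch, assuming hps is a loaded HParams instance whose
# training_files list points at audio paths with *.data.pt siblings
# produced by the preprocessing step this PR moves the work into.
from torch.utils.data import DataLoader

from so_vits_svc_fork.data_utils import TextAudioCollate, TextAudioSpeakerLoader

dataset = TextAudioSpeakerLoader(hps, is_validation=False)
loader = DataLoader(
    dataset,
    batch_size=hps.train.batch_size,  # assumed hparams field
    collate_fn=TextAudioCollate(),
    shuffle=False,  # the dataset already shuffles its paths with a seeded Random
)
# Unpacking order matches the tuple returned by TextAudioCollate above.
content, f0, spec, mel, audio, spk, lengths, uv = next(iter(loader))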
70 changes: 64 additions & 6 deletions src/so_vits_svc_fork/modules/mel_processing.py
@@ -1,3 +1,49 @@
"""from logging import getLogger

import torch
import torch.utils.data
import torchaudio

LOG = getLogger(__name__)


from ..hparams import HParams


def spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
return torchaudio.transforms.Spectrogram(
n_fft=hps.data.filter_length,
win_length=hps.data.win_length,
hop_length=hps.data.hop_length,
power=1.0,
window_fn=torch.hann_window,
normalized=False,
).to(audio.device)(audio)


def spec_to_mel_torch(spec: torch.Tensor, hps: HParams) -> torch.Tensor:
return torchaudio.transforms.MelScale(
n_mels=hps.data.n_mel_channels,
sample_rate=hps.data.sampling_rate,
f_min=hps.data.mel_fmin,
f_max=hps.data.mel_fmax,
).to(spec.device)(spec)


def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
return torchaudio.transforms.MelSpectrogram(
sample_rate=hps.data.sampling_rate,
n_fft=hps.data.filter_length,
n_mels=hps.data.n_mel_channels,
win_length=hps.data.win_length,
hop_length=hps.data.hop_length,
f_min=hps.data.mel_fmin,
f_max=hps.data.mel_fmax,
power=1.0,
window_fn=torch.hann_window,
normalized=False,
).to(audio.device)(audio)"""

from logging import getLogger

import torch
@@ -41,12 +87,14 @@ def spectral_de_normalize_torch(magnitudes):
 hann_window = {}
 
 
-def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+def spectrogram_torch(y, hps, center=False):
     if torch.min(y) < -1.0:
         LOG.info("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
         LOG.info("max value is ", torch.max(y))
-
+    n_fft = hps.data.filter_length
+    hop_size = hps.data.hop_length
+    win_size = hps.data.win_length
     global hann_window
     dtype_device = str(y.dtype) + "_" + str(y.device)
     wnsize_dtype_device = str(win_size) + "_" + dtype_device
@@ -79,7 +127,12 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
     return spec
 
 
-def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+def spec_to_mel_torch(spec, hps):
+    sampling_rate = hps.data.sampling_rate
+    n_fft = hps.data.filter_length
+    num_mels = hps.data.n_mel_channels
+    fmin = hps.data.mel_fmin
+    fmax = hps.data.mel_fmax
     global mel_basis
     dtype_device = str(spec.dtype) + "_" + str(spec.device)
     fmax_dtype_device = str(fmax) + "_" + dtype_device
@@ -95,9 +148,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
     return spec
 
 
-def mel_spectrogram_torch(
-    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
-):
+def mel_spectrogram_torch(y, hps, center=False):
+    sampling_rate = hps.data.sampling_rate
+    n_fft = hps.data.filter_length
+    num_mels = hps.data.n_mel_channels
+    fmin = hps.data.mel_fmin
+    fmax = hps.data.mel_fmax
+    hop_size = hps.data.hop_length
+    win_size = hps.data.win_length
     if torch.min(y) < -1.0:
         LOG.info(f"min value is {torch.min(y)}")
     if torch.max(y) > 1.0:
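The net effect on call sites is that each function now takes the whole HParams object instead of unpacked STFT parameters. A minimal sketch, assuming audio is a [1, T] float tensor in [-1, 1] and hps carries the usual data.* fields:

# Call-site sketch for the new signatures; audio and hps are assumed inputs.
from so_vits_svc_fork.modules.mel_processing import (
    mel_spectrogram_torch,
    spec_to_mel_torch,
    spectrogram_torch,
)

spec = spectrogram_torch(audio, hps)      # linear spectrogram, [1, n_fft // 2 + 1, frames]
mel = spec_to_mel_torch(spec, hps)        # mel spectrogram, [1, n_mel_channels, frames]
mel2 = mel_spectrogram_torch(audio, hps)  # effectively the two steps above chained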