Skip to content

Commit

Permalink
Revert "fix: rollback changes"
Browse files Browse the repository at this point in the history
This reverts commit 652e60a.
  • Loading branch information
34j committed Apr 3, 2023
1 parent 78c2d95 commit eea693e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 169 deletions.
168 changes: 6 additions & 162 deletions src/so_vits_svc_fork/modules/mel_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""from logging import getLogger
from logging import getLogger

import torch
import torch.utils.data
Expand Down Expand Up @@ -27,6 +27,8 @@ def spec_to_mel_torch(spec: torch.Tensor, hps: HParams) -> torch.Tensor:
sample_rate=hps.data.sampling_rate,
f_min=hps.data.mel_fmin,
f_max=hps.data.mel_fmax,
mel_scale="slaney",
norm="slaney",
).to(spec.device)(spec)


Expand All @@ -41,165 +43,7 @@ def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
f_max=hps.data.mel_fmax,
power=1.0,
window_fn=torch.hann_window,
mel_scale="slaney",
norm="slaney",
normalized=False,
).to(audio.device)(audio)"""

from logging import getLogger

import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

LOG = getLogger(__name__)

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output


def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, hps, center=False):
if torch.min(y) < -1.0:
LOG.info("min value is ", torch.min(y))
if torch.max(y) > 1.0:
LOG.info("max value is ", torch.max(y))
n_fft = hps.data.filter_length
hop_size = hps.data.hop_length
win_size = hps.data.win_length
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def spec_to_mel_torch(spec, hps):
sampling_rate = hps.data.sampling_rate
n_fft = hps.data.filter_length
num_mels = hps.data.n_mel_channels
fmin = hps.data.mel_fmin
fmax = hps.data.mel_fmax
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec


def mel_spectrogram_torch(y, hps, center=False):
sampling_rate = hps.data.sampling_rate
n_fft = hps.data.filter_length
num_mels = hps.data.n_mel_channels
fmin = hps.data.mel_fmin
fmax = hps.data.mel_fmax
hop_size = hps.data.hop_length
win_size = hps.data.win_length
if torch.min(y) < -1.0:
LOG.info(f"min value is {torch.min(y)}")
if torch.max(y) > 1.0:
LOG.info(f"max value is {torch.max(y)}")

global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)

return spec
).to(audio.device)(audio)
10 changes: 4 additions & 6 deletions src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import librosa
import numpy as np
import torch
import torchaudio
from fairseq.models.hubert import HubertModel
from joblib import Parallel, delayed
from tqdm import tqdm
Expand All @@ -17,7 +16,7 @@
from so_vits_svc_fork import utils

from ..hparams import HParams
from ..modules.mel_processing import spec_to_mel_torch, spectrogram_torch
from ..modules.mel_processing import mel_spectrogram_torch, spectrogram_torch
from ..utils import get_total_gpu_memory
from .preprocess_utils import check_hubert_min_duration

Expand Down Expand Up @@ -66,9 +65,8 @@ def _process_one(
torch.cuda.empty_cache()

# Compute spectrogram
audio, sr = torchaudio.load(filepath)
spec = spectrogram_torch(audio, hps).squeeze(0)
mel_spec = spec_to_mel_torch(spec, hps)
spec = spectrogram_torch(audio, hps)
mel_spec = mel_spectrogram_torch(audio, hps)
torch.cuda.empty_cache()

# fix lengths
Expand All @@ -94,7 +92,7 @@ def _process_one(
"f0": f0,
"uv": uv,
"content": c,
"audio": audio,
"audio": audio.unsqueeze(0),
"spk": spk,
}
data = {k: v.cpu() for k, v in data.items()}
Expand Down
2 changes: 1 addition & 1 deletion src/so_vits_svc_fork/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _train_and_evaluate(
) = net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths)

y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
mel, ids_slice, hps.train.segment_size // hps.data.hop_length + 1
)
y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), hps)
y = commons.slice_segments(
Expand Down

0 comments on commit eea693e

Please sign in to comment.