Skip to content

Commit

Permalink
♿ pin resemble-enhance
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 8, 2024
1 parent 4256371 commit b8f41f9
Show file tree
Hide file tree
Showing 55 changed files with 4,547 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,4 +388,5 @@ Style 中带有 `_p` 的使用了 prompt + prefix,而不带 `_p` 的则只使

- ChatTTS: https://github.com/2noise/ChatTTS
- PaddleSpeech: https://github.com/PaddlePaddle/PaddleSpeech
- resemble-enhance: https://github.com/resemble-ai/resemble-enhance
- 默认说话人: https://github.com/2noise/ChatTTS/issues/238
6 changes: 3 additions & 3 deletions modules/Enhancer/ResembleEnhance.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from typing import List
from resemble_enhance.enhancer.enhancer import Enhancer
from resemble_enhance.enhancer.hparams import HParams
from resemble_enhance.inference import inference
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
from modules.repos_static.resemble_enhance.inference import inference

import torch

Expand Down
Empty file.
5 changes: 5 additions & 0 deletions modules/repos_static/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# repos static

## resemble_enhance

Vendored copy of the upstream source: https://github.com/resemble-ai/resemble-enhance/tree/main
Empty file.
55 changes: 55 additions & 0 deletions modules/repos_static/resemble_enhance/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class Normalizer(nn.Module):
    """Normalize tensors with exponentially-averaged scalar statistics.

    A single scalar mean and a single scalar variance (computed over *all*
    elements of the input) are tracked in two 0-dim buffers. Both buffers
    start out as NaN, which is used as the "no statistics collected yet"
    marker; until the first update the module behaves as the identity
    (mean 0, std 1).

    Args:
        momentum: Weight given to the newest batch statistic in the moving
            average (``new = (1 - momentum) * old + momentum * batch``).
        eps: Value added to the running variance before taking the square
            root in ``running_std``.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        # NaN-filled scalars: "unsafe" because they must not be used before
        # the first update; go through `running_mean` / `running_std` instead.
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self):
        """Whether at least one update has replaced the initial NaN mean."""
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self):
        """Running mean, or zero if no statistics have been collected yet."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self):
        """Running standard deviation, or one if no statistics exist yet."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor):
        """Blend the old value ``a`` with the new observation ``x``."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x):
        """Fold the statistics of ``x`` into the running buffers in place.

        The whole update runs without gradient tracking. Previously only the
        moving-average branch did, so the very first update stored
        ``x.mean()`` / ``x.var()`` with their autograd graph attached: the
        buffers then required grad and kept the first batch's graph alive.
        """
        if not self.started:
            # First observation: adopt the batch statistics directly.
            self.running_mean_unsafe = x.mean()
            self.running_var_unsafe = x.var()
        else:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
            # Variance is measured around the freshly updated running mean.
            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())

    def forward(self, x: Tensor, update=True):
        """Return ``(x - running_mean) / running_std``.

        Args:
            x: Tensor to normalize.
            update: When true and the module is in training mode, the running
                statistics are updated from ``x`` before normalizing and the
                current values are exposed as Python floats in ``self.stats``.
        """
        if self.training and update:
            self.update_(x)
            self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
        x = (x - self.running_mean) / self.running_std
        return x

    def inverse(self, x: Tensor):
        """Undo ``forward`` using the current running statistics."""
        return x * self.running_std + self.running_mean
48 changes: 48 additions & 0 deletions modules/repos_static/resemble_enhance/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
import random

from torch.utils.data import DataLoader

from ..hparams import HParams
from .dataset import Dataset
from .utils import mix_fg_bg, rglob_audio_files

logger = logging.getLogger(__name__)


def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Build the training and validation datasets from ``hp.fg_dir``.

    The audio files found under ``hp.fg_dir`` are shuffled with a dedicated
    RNG seeded by ``seed`` (so the split is reproducible); the last
    ``val_size`` shuffled paths form the validation set and everything before
    them the training set.

    Returns:
        A ``(train_ds, val_ds)`` tuple of ``Dataset`` objects.
    """
    all_paths = rglob_audio_files(hp.fg_dir)
    logger.info(f"Found {len(all_paths)} audio files in {hp.fg_dir}")

    # A private Random instance keeps the split independent of global RNG state.
    rng = random.Random(seed)
    rng.shuffle(all_paths)
    train_paths, val_paths = all_paths[:-val_size], all_paths[-val_size:]

    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
    val_ds = Dataset(val_paths, hp, training=False, mode=mode)
    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")

    return train_ds, val_ds


def create_dataloaders(hp: HParams, mode):
    """Create the training and validation ``DataLoader`` pair.

    The training loader shuffles, uses ``hp.batch_size_per_gpu`` and drops the
    last incomplete batch; the validation loader yields single-item batches in
    a fixed order. Both use ``hp.nj`` worker processes and the dataset's own
    ``collate_fn``.

    Returns:
        A ``(train_dl, val_dl)`` tuple.
    """
    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)

    # Options that differ between the two loaders.
    train_opts = dict(batch_size=hp.batch_size_per_gpu, shuffle=True, drop_last=True)
    val_opts = dict(batch_size=1, shuffle=False, drop_last=False)

    train_dl = DataLoader(train_ds, num_workers=hp.nj, collate_fn=train_ds.collate_fn, **train_opts)
    val_dl = DataLoader(val_ds, num_workers=hp.nj, collate_fn=val_ds.collate_fn, **val_opts)
    return train_dl, val_dl
171 changes: 171 additions & 0 deletions modules/repos_static/resemble_enhance/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import logging
import random
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as DatasetBase

from ..hparams import HParams
from .distorter import Distorter
from .utils import rglob_audio_files

logger = logging.getLogger(__name__)


def _normalize(x):
return x / (np.abs(x).max() + 1e-7)


def _collate(batch, key, tensor=True, pad=True):
l = [d[key] for d in batch]
if l[0] is None:
return None
if tensor:
l = [torch.from_numpy(x) for x in l]
if pad:
assert tensor, "Can't pad non-tensor"
l = pad_sequence(l, batch_first=True)
return l


def praat_augment(wav, sr):
    """Apply Praat's "Change gender" command to ``wav`` with random settings.

    Args:
        wav: One-dimensional waveform array.
        sr: Sampling rate passed to ``parselmouth.Sound``.

    Returns:
        The first channel of the processed sound as a ``float32`` array.

    Raises:
        ImportError: If ``parselmouth`` is not installed.
    """
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")

    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # note that this function may hang if the praat version is 0.4.3
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    snd = parselmouth.Sound(wav, sr)

    # Two draws per call, in this order, to keep RNG consumption unchanged.
    formant_ratio = random.uniform(1.1, 1.5)
    pitch_range = random.uniform(0.5, 2.0)

    # The fixed arguments (75, 600, 0, 1.0) are passed through verbatim;
    # presumably pitch floor/ceiling, new pitch median and duration factor
    # as defined by Praat's "Change gender" -- TODO confirm.
    snd = parselmouth.praat.call(snd, "Change gender", 75, 600, formant_ratio, 0, pitch_range, 1.0)

    channels = np.array(snd.values)
    return channels[0].astype(np.float32)


class Dataset(DatasetBase):
    """Dataset yielding foreground/background waveforms and distorted copies.

    Each item is a dict with the keys ``fg_wav``, ``bg_wav``, ``fg_dwav`` and
    ``bg_dwav``. The ``*_dwav`` entries are produced by running the waveform
    through a ``Distorter`` and re-normalizing. When ``hp.load_fg_only`` is
    set, only ``fg_wav`` is populated and the other three are ``None``.

    Args:
        fg_paths: Foreground audio files served by this dataset.
        hp: Hyper-parameters; this class reads ``bg_dir``, ``fg_dir``,
            ``wav_rate``, ``training_seconds``, ``praat_augment_prob`` and
            ``load_fg_only`` from it.
        training: Enables random cropping, random background selection,
            silent-foreground substitution and Praat augmentation.
        max_retries: Number of attempts ``__getitem__`` makes before failing.
        silent_fg_prob: Probability (training only) of replacing the
            foreground with silence.
        mode: Must be ``"enhancer"`` or ``"denoiser"``; forwarded to the
            ``Distorter``.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        super().__init__()

        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        num_fg = len(self.fg_paths)
        if num_fg == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        num_bg = len(self.bg_paths)
        if num_bg == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(f"Found {num_fg} foreground files and {num_bg} background files")

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob

        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load ``path`` as a normalized mono waveform at ``hp.wav_rate``.

        Args:
            path: Audio file to read.
            length: Target number of samples. When ``None`` in training mode
                it defaults to ``hp.training_seconds * hp.wav_rate``; when
                ``None`` otherwise the full file is returned.
            random_crop: Pick the crop start at random instead of at 0.

        Returns:
            A 1-D NumPy array, cropped and/or zero-padded at the end to
            ``length`` samples when a length applies, then peak-normalized.
        """
        waveform, source_rate = torchaudio.load(path)

        waveform = AF.resample(
            waveform=waveform,
            orig_freq=source_rate,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        samples = waveform.float().numpy()

        # Collapse a 2-D array to 1-D by averaging over the first axis.
        if samples.ndim == 2:
            samples = np.mean(samples, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        # No target length: return the whole file.
        if length is None:
            return _normalize(samples)

        if random_crop:
            start = random.randint(0, max(0, len(samples) - length))
        else:
            start = 0
        samples = samples[start : start + length]

        # Short files are zero-padded at the end up to the target length.
        missing = length - len(samples)
        if missing > 0:
            samples = np.pad(samples, (0, missing))

        return _normalize(samples)

    def _getitem_unsafe(self, index: int):
        """Build the sample for ``index``; may raise on any loading error."""
        hp = self.hp
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally train on a silent foreground.
            fg_wav = np.zeros(int(hp.training_seconds * hp.wav_rate), dtype=np.float32)
        else:
            fg_wav = self._load_wav(fg_path)
            if random.random() < hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, hp.wav_rate)

        sample = dict(fg_wav=fg_wav, bg_wav=None, fg_dwav=None, bg_dwav=None)
        if hp.load_fg_only:
            return sample

        sample["fg_dwav"] = _normalize(self.distorter(fg_wav, hp.wav_rate)).astype(np.float32)

        if self.training:
            bg_path = random.choice(self.bg_paths)
        else:
            # Deterministic for validation
            bg_path = self.bg_paths[index % len(self.bg_paths)]

        bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
        sample["bg_wav"] = bg_wav
        sample["bg_dwav"] = _normalize(self.distorter(bg_wav, hp.wav_rate)).astype(np.float32)

        return sample

    def __getitem__(self, index: int):
        """Return the sample at ``index``, retrying on failure.

        When loading fails, a different index is drawn at random and tried
        instead, up to ``max_retries`` attempts in total.

        Raises:
            RuntimeError: If the final attempt also fails; the message names
                the path tried last and chains the underlying exception.
        """
        final_attempt = self.max_retries - 1
        for attempt in range(self.max_retries):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if attempt == final_attempt:
                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                index = np.random.randint(0, len(self))

    def __len__(self):
        """Number of foreground files."""
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate per-sample dicts into a dict of batched (padded) fields."""
        field_map = (
            ("fg_wavs", "fg_wav"),
            ("bg_wavs", "bg_wav"),
            ("fg_dwavs", "fg_dwav"),
            ("bg_dwavs", "bg_dwav"),
        )
        return {batched: _collate(batch, single) for batched, single in field_map}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .distorter import Distorter
Loading

0 comments on commit b8f41f9

Please sign in to comment.