-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
55 changed files
with
4,547 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# repos static | ||
|
||
## resemble_enhance | ||
|
||
https://github.com/resemble-ai/resemble-enhance/tree/main |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import logging | ||
|
||
import torch | ||
from torch import Tensor, nn | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Normalizer(nn.Module):
    """Tracks running mean/variance of its inputs via an EMA and whitens them.

    Before the first update the statistics default to mean 0 / std 1, so the
    module behaves as the identity. `inverse` maps normalized values back to
    the original scale.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Scalar buffers start as NaN; NaN doubles as the "never updated" flag.
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self):
        # True once update_() has seeded the statistics at least once.
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self):
        # Fall back to 0 until the first update.
        return self.running_mean_unsafe if self.started else torch.zeros_like(self.running_mean_unsafe)

    @property
    def running_std(self):
        # Fall back to 1 until the first update.
        if self.started:
            return (self.running_var_unsafe + self.eps).sqrt()
        return torch.ones_like(self.running_var_unsafe)

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor):
        # One exponential-moving-average step toward x.
        return (1 - self.momentum) * a + self.momentum * x

    def update_(self, x):
        """Fold the batch statistics of `x` into the running estimates."""
        if self.started:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
        else:
            # First batch seeds the estimates directly.
            self.running_mean_unsafe = x.mean()
            self.running_var_unsafe = x.var()

    def forward(self, x: Tensor, update=True):
        """Normalize `x`; in training mode (with update=True) refresh stats first."""
        if self.training and update:
            self.update_(x)
        # Expose the stats used for this pass (handy for logging/debugging).
        self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
        return (x - self.running_mean) / self.running_std

    def inverse(self, x: Tensor):
        """Undo the normalization applied by forward()."""
        return x * self.running_std + self.running_mean
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import logging | ||
import random | ||
|
||
from torch.utils.data import DataLoader | ||
|
||
from ..hparams import HParams | ||
from .dataset import Dataset | ||
from .utils import mix_fg_bg, rglob_audio_files | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Split the foreground audio files into train/val Dataset objects.

    The split is deterministic: a seeded shuffle, then the last `val_size`
    files are held out for validation.
    """
    audio_paths = rglob_audio_files(hp.fg_dir)
    logger.info(f"Found {len(audio_paths)} audio files in {hp.fg_dir}")

    # Seeded local RNG keeps the split reproducible without touching global state.
    random.Random(seed).shuffle(audio_paths)
    train_ds = Dataset(audio_paths[:-val_size], hp, training=True, mode=mode)
    val_ds = Dataset(audio_paths[-val_size:], hp, training=False, mode=mode)

    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")

    return train_ds, val_ds
|
||
|
||
def create_dataloaders(hp: HParams, mode):
    """Build the train/val DataLoaders for the given mode.

    Training batches are shuffled and drop the last partial batch; validation
    runs sample-by-sample, in order, keeping everything.
    """
    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)

    shared = dict(num_workers=hp.nj)
    train_dl = DataLoader(
        train_ds,
        batch_size=hp.batch_size_per_gpu,
        shuffle=True,
        drop_last=True,
        collate_fn=train_ds.collate_fn,
        **shared,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=val_ds.collate_fn,
        **shared,
    )
    return train_dl, val_dl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import logging | ||
import random | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio | ||
import torchaudio.functional as AF | ||
from torch.nn.utils.rnn import pad_sequence | ||
from torch.utils.data import Dataset as DatasetBase | ||
|
||
from ..hparams import HParams | ||
from .distorter import Distorter | ||
from .utils import rglob_audio_files | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _normalize(x): | ||
return x / (np.abs(x).max() + 1e-7) | ||
|
||
|
||
def _collate(batch, key, tensor=True, pad=True): | ||
l = [d[key] for d in batch] | ||
if l[0] is None: | ||
return None | ||
if tensor: | ||
l = [torch.from_numpy(x) for x in l] | ||
if pad: | ||
assert tensor, "Can't pad non-tensor" | ||
l = pad_sequence(l, batch_first=True) | ||
return l | ||
|
||
|
||
def praat_augment(wav, sr):
    """Randomly shift the formants and pitch range of a mono waveform via Praat.

    Requires the optional `parselmouth` dependency; raises ImportError with an
    install hint otherwise.
    """
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
    # Needs a post-0.4.3 parselmouth build — 0.4.3 may hang on this call.
    # See https://github.com/YannickJadoul/Parselmouth/issues/68
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
    sound = parselmouth.Sound(wav, sr)
    # Draw order matters for RNG reproducibility: formant shift first, then pitch range.
    formant_shift = random.uniform(1.1, 1.5)
    pitch_range = random.uniform(0.5, 2.0)
    sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift, 0, pitch_range, 1.0)
    return np.array(sound.values)[0].astype(np.float32)
|
||
|
||
class Dataset(DatasetBase):
    """Paired foreground (speech) / background (noise) audio dataset.

    Each item yields the clean foreground waveform and — unless
    ``hp.load_fg_only`` — distorted foreground/background waveforms produced
    by the Distorter for the given mode ("enhancer" or "denoiser").
    All waveforms are 1-D float32 numpy arrays at ``hp.wav_rate``.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        super().__init__()

        # NOTE(review): the default mode=False is deliberately invalid — it
        # forces every caller to pass "enhancer" or "denoiser" explicitly.
        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        # Background (noise) files are discovered here, not passed in.
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")

        self.training = training
        self.max_retries = max_retries
        # Probability of substituting an all-zero foreground during training,
        # so the model also sees noise-only inputs.
        self.silent_fg_prob = silent_fg_prob

        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load audio, resample to hp.wav_rate, mono-mix, crop/pad, peak-normalize.

        When training and no length is given, crops/pads to
        hp.training_seconds worth of samples. Returns a 1-D float32 array.
        """
        wav, sr = torchaudio.load(path)

        # High-quality Kaiser-windowed sinc resampling; the rolloff/beta
        # constants match torchaudio's "kaiser_best" preset.
        wav = AF.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        wav = wav.float().numpy()

        # Mix multi-channel audio down to mono.
        if wav.ndim == 2:
            wav = np.mean(wav, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                start = random.randint(0, max(0, len(wav) - length))
                wav = wav[start : start + length]
            else:
                wav = wav[:length]

        # Zero-pad short clips up to the requested length.
        if length is not None and len(wav) < length:
            wav = np.pad(wav, (0, length - len(wav)))

        wav = _normalize(wav)

        return wav

    def _getitem_unsafe(self, index: int):
        """Build one sample; may raise on unreadable files (see __getitem__)."""
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally feed a silent foreground (noise-only sample).
            fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
        else:
            fg_wav = self._load_wav(fg_path)
            # NOTE(review): random.random() is drawn before the training check,
            # so it consumes RNG state even in eval mode — relevant for seeded runs.
            if random.random() < self.hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, self.hp.wav_rate)

        if self.hp.load_fg_only:
            # Foreground-only mode: skip background loading and distortion.
            bg_wav = None
            fg_dwav = None
            bg_dwav = None
        else:
            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]
            # Background is cropped/padded to match the foreground length.
            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Fetch a sample, retrying with a random index on failure.

        Raises RuntimeError (chained to the last error) after max_retries.
        """
        for i in range(self.max_retries):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if i == self.max_retries - 1:
                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                # Resample a new index so one bad file cannot stall the loader.
                index = np.random.randint(0, len(self))

    def __len__(self):
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate sample dicts into padded batch tensors (None fields stay None)."""
        return dict(
            fg_wavs=_collate(batch, "fg_wav"),
            bg_wavs=_collate(batch, "bg_wav"),
            fg_dwavs=_collate(batch, "fg_dwav"),
            bg_dwavs=_collate(batch, "bg_dwav"),
        )
1 change: 1 addition & 0 deletions
1
modules/repos_static/resemble_enhance/data/distorter/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .distorter import Distorter |
Oops, something went wrong.