Skip to content

Commit

Permalink
🐳 fix docker / 兼容 py 3.9 (compatible with Python 3.9)
Browse files: browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Jun 12, 2024
1 parent 57262b8 commit ebb096f
Show file tree
Hide file tree
Showing 15 changed files with 139 additions and 54 deletions.
11 changes: 3 additions & 8 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@ COPY . ./
USER root

RUN sed -i s@/archive.ubuntu.com/@/mirrors.tuna.tsinghua.edu.cn/@g /etc/apt/sources.list && \
apt-get update -y && \
apt-get update -y --allow-unauthenticated --fix-missing && \
apt-get install -y software-properties-common && \
apt-get install --no-install-recommends -y ffmpeg rubberband-cli && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
rm /etc/apt/sources.list.d/cuda.list

RUN apt-get update -y --allow-unauthenticated --fix-missing && \
add-apt-repository ppa:savoury1/ffmpeg4 -y && \
apt-get update -y && \
apt-get update && xargs -r -a packages.txt apt-get install -y && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
apt-get clean

RUN pip install -r requirements.docker.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
3 changes: 2 additions & 1 deletion modules/repos_static/resemble_enhance/data/distorter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import time
from typing import Union
import warnings

import numpy as np
Expand Down Expand Up @@ -87,7 +88,7 @@ def apply(self, wav, sr):


class Permutation(Effect):
def __init__(self, *effects, n: int | None = None):
def __init__(self, *effects, n: Union[int, None] = None):
super().__init__()
self.effects = effects
self.n = n
Expand Down
11 changes: 8 additions & 3 deletions modules/repos_static/resemble_enhance/data/distorter/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Union

import librosa
import numpy as np
Expand All @@ -16,7 +17,7 @@

@dataclass
class RandomRIR(Effect):
rir_dir: Path | None
rir_dir: Union[Path, None]
rir_rate: int = 44_000
rir_suffix: str = ".npy"
deterministic: bool = False
Expand Down Expand Up @@ -49,7 +50,9 @@ def apply(self, wav, sr):

length = len(wav)

wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
wav = librosa.resample(
wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast"
)
rir = self._sample_rir()

wav = signal.convolve(wav, rir, mode="same")
Expand All @@ -58,7 +61,9 @@ def apply(self, wav, sr):
if actlev > 0.99:
wav = (wav / actlev) * 0.98

wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
wav = librosa.resample(
wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast"
)

if abs(length - len(wav)) > 10:
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
Expand Down
40 changes: 32 additions & 8 deletions modules/repos_static/resemble_enhance/data/distorter/sox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import random
from typing import Union
import warnings
from functools import partial

Expand Down Expand Up @@ -29,7 +30,9 @@ def apply(self, wav: np.ndarray, sr: int):
chain = augment.EffectChain()
chain = self.attach(chain)
tensor = torch.from_numpy(wav)[None].float() # (1, T)
tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
tensor = chain.apply(
tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}
)
wav = tensor.numpy()[0] # (T,)
return wav

Expand All @@ -41,7 +44,9 @@ def __init__(self, effect_name: str, *args, **kwargs):
self.kwargs = kwargs

def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
_logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
_logger.debug(
f"Attaching {self.effect_name} with {self.args} and {self.kwargs}"
)
if not hasattr(chain, self.effect_name):
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
Expand Down Expand Up @@ -115,21 +120,30 @@ def __call__(self) -> str:


class Concat(Generator):
def __init__(self, *parts: Generator | str):
def __init__(self, *parts: Union[Generator, str]):
self.parts = parts

def __call__(self):
return "".join([part if isinstance(part, str) else part() for part in self.parts])
return "".join(
[part if isinstance(part, str) else part() for part in self.parts]
)


class RandomLowpassDistorter(SoxEffect):
def __init__(self, low=2000, high=16000):
super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
super().__init__(
"sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))
)


class RandomBandpassDistorter(SoxEffect):
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
super().__init__(
"sinc",
"-n",
Randint(50, 200),
partial(self._fn, low, high, min_width, max_width),
)

@staticmethod
def _fn(low, high, min_width, max_width):
Expand All @@ -139,7 +153,15 @@ def _fn(low, high, min_width, max_width):


class RandomEqualizer(SoxEffect):
def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
def __init__(
self,
low=100,
high=4000,
q_low=1,
q_high=5,
db_low: int = -30,
db_high: int = 30,
):
super().__init__(
"equalizer",
Uniform(low, high),
Expand All @@ -150,7 +172,9 @@ def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_

class RandomOverdrive(SoxEffect):
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
super().__init__(
"overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)
)


class RandomReverb(Chain):
Expand Down
6 changes: 4 additions & 2 deletions modules/repos_static/resemble_enhance/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Callable
from typing import Callable, Union

from torch import Tensor

Expand All @@ -16,7 +16,9 @@ def rglob_audio_files(path: Path):
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))


def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
def mix_fg_bg(
fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7
):
"""
Args:
fg: (b, t)
Expand Down
3 changes: 2 additions & 1 deletion modules/repos_static/resemble_enhance/denoiser/denoiser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -154,7 +155,7 @@ def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
sep_sin = sin * cos_res + cos * sin_res
return sep_mag, sep_cos, sep_sin

def forward(self, x: Tensor, y: Tensor | None = None):
def forward(self, x: Tensor, y: Union[Tensor, None] = None):
"""
Args:
x: (b t), a mixed audio
Expand Down
11 changes: 8 additions & 3 deletions modules/repos_static/resemble_enhance/enhancer/download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
from typing import Union

import torch

Expand All @@ -12,14 +13,18 @@ def get_source_url(relpath):
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"


def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
if run_dir is None:
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
return Path(run_dir) / relpath


def download(run_dir: str | Path | None = None):
relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
def download(run_dir: Union[str, Path, None] = None):
relpaths = [
"hparams.yaml",
"ds/G/latest",
"ds/G/default/mp_rank_00_model_states.pt",
]
for relpath in relpaths:
path = get_target_path(relpath, run_dir=run_dir)
if path.exists():
Expand Down
7 changes: 5 additions & 2 deletions modules/repos_static/resemble_enhance/enhancer/enhancer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Union

import matplotlib.pyplot as plt
import pandas as pd
Expand Down Expand Up @@ -109,7 +110,7 @@ def to_mel(self, x: Tensor, drop_last=True):
return self.mel_fn(x)[..., :-1] # (b d t)
return self.mel_fn(x)

def _may_denoise(self, x: Tensor, y: Tensor | None = None):
def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
if self.hp.lcfm_training_mode == "cfm":
return self.denoiser(x, y)
return x
Expand All @@ -126,7 +127,9 @@ def configurate_(self, nfe, solver, lambd, tau):
self.lcfm.eval_tau_(tau)
self._eval_lambd = lambd

def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
def forward(
self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None
):
"""
Args:
x: (b t), mix wavs (fg + bg)
Expand Down
7 changes: 4 additions & 3 deletions modules/repos_static/resemble_enhance/enhancer/hparams.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from ..hparams import HParams as HParamsBase

Expand All @@ -17,7 +18,7 @@ class HParams(HParamsBase):

vocoder_extra_dim: int = 32

gan_training_start_step: int | None = 5_000
enhancer_stage1_run_dir: Path | None = None
gan_training_start_step: Union[int, None] = 5_000
enhancer_stage1_run_dir: Union[Path, None] = None

denoiser_run_dir: Path | None = None
denoiser_run_dir: Union[Path, None] = None
3 changes: 2 additions & 1 deletion modules/repos_static/resemble_enhance/enhancer/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from functools import cache
from pathlib import Path
from typing import Union

import torch

Expand All @@ -13,7 +14,7 @@


@cache
def load_enhancer(run_dir: str | Path | None, device):
def load_enhancer(run_dir: Union[str, Path, None], device):
run_dir = download(run_dir)
hp = HParams.load(run_dir)
enhancer = Enhancer(hp)
Expand Down
29 changes: 20 additions & 9 deletions modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass
from functools import partial
from typing import Protocol
from typing import Protocol, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -17,8 +17,7 @@


class VelocityField(Protocol):
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
...
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...


class Solver:
Expand All @@ -40,7 +39,9 @@ def __init__(

self._camera = None
self._mel_fn = mel_fn
self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
self._time_mapping = partial(
self.exponential_decay_mapping, n=time_mapping_divisor
)

def configurate_(self, nfe=None, method=None):
if nfe is None:
Expand All @@ -50,7 +51,9 @@ def configurate_(self, nfe=None, method=None):
method = self.method

if nfe == 1 and method in ("midpoint", "rk4"):
logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
logger.warning(
f"1 NFE is not supported for {method}, using euler method instead."
)
method = "euler"

self.nfe = nfe
Expand Down Expand Up @@ -105,7 +108,9 @@ def _maybe_camera_snap(self, *, ψt, t):
)
else:
# Spectrogram, b c t
plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
plt.imshow(
ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none"
)
ax = plt.gca()
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
camera.snap()
Expand Down Expand Up @@ -271,7 +276,7 @@ def __post_init__(self):
global_dim=self.time_emb_dim,
)

def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None):
"""
Perturb ψ1 to ψt.
"""
Expand Down Expand Up @@ -311,7 +316,7 @@ def _to_u(self, *, ψ1, ψ0: Tensor):
"""
return ψ1 - ψ0

def _to_v(self, *, ψt, x, t: float | Tensor):
def _to_v(self, *, ψt, x, t: Union[float, Tensor]):
"""
Args:
ψt: (b c t)
Expand Down Expand Up @@ -364,7 +369,13 @@ def sample(self, x, ψ0=None, t0=0.0):
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
return ψ1

def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
def forward(
self,
x: Tensor,
y: Union[Tensor, None] = None,
ψ0: Union[Tensor, None] = None,
t0=0.0,
):
if y is None:
y = self.sample(x, ψ0=ψ0, t0=t0)
else:
Expand Down
3 changes: 2 additions & 1 deletion modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Union

import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -14,7 +15,7 @@
@dataclass
class IRMAEOutput:
latent: Tensor # latent vector
decoded: Tensor | None # decoder output, include extra dim
decoded: Union[Tensor, None] # decoder output, include extra dim


class ResBlock(nn.Sequential):
Expand Down
Loading

0 comments on commit ebb096f

Please sign in to comment.