diff --git a/docs/source/conf.py b/docs/source/conf.py
index d36c22f..1331f23 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,8 +12,8 @@
 author = "Steven Atkinson"
 
 # TODO update this automatically from nam.__version__!
-release = "0.11"
-version = "0.11.1"
+release = "0.12"
+version = "0.12.0"
 
 
 # -- General configuration
diff --git a/nam/_core.py b/nam/_core.py
index 74d3535..ab29dba 100644
--- a/nam/_core.py
+++ b/nam/_core.py
@@ -2,7 +2,7 @@
 # Created Date: Saturday February 5th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-from copy import deepcopy
+from copy import deepcopy as _deepcopy
 
 
 class InitializableFromConfig(object):
@@ -12,4 +12,4 @@ def init_from_config(cls, config):
 
     @classmethod
     def parse_config(cls, config):
-        return deepcopy(config)
+        return _deepcopy(config)
diff --git a/nam/cli.py b/nam/cli.py
index 513baee..e527e21 100644
--- a/nam/cli.py
+++ b/nam/cli.py
@@ -77,17 +77,17 @@ def removesuffix(s: str, suffix: str) -> str:
 
 _apply_extensions()
 
-import json
-from argparse import ArgumentParser
-from pathlib import Path
+import json as _json
+from argparse import ArgumentParser as _ArgumentParser
+from pathlib import Path as _Path
 
 from nam.train.full import main as _nam_full
-from nam.train.gui import run as nam_gui  # noqa F401 Used as an entry point
-from nam.util import timestamp
+from nam.train.gui import run as _nam_gui  # noqa F401 Used as an entry point
+from nam.util import timestamp as _timestamp
 
 
 def nam_full():
-    parser = ArgumentParser()
+    parser = _ArgumentParser()
     parser.add_argument("data_config_path", type=str)
     parser.add_argument("model_config_path", type=str)
     parser.add_argument("learning_config_path", type=str)
@@ -96,17 +96,17 @@ def nam_full():
     args = parser.parse_args()
 
-    def ensure_outdir(outdir: str) -> Path:
-        outdir = Path(outdir, timestamp())
+    def ensure_outdir(outdir: str) -> _Path:
+        outdir = _Path(outdir, _timestamp())
         outdir.mkdir(parents=True, exist_ok=False)
         return outdir
 
     outdir = ensure_outdir(args.outdir)
     # Read
     with open(args.data_config_path, "r") as fp:
-        data_config = json.load(fp)
+        data_config = _json.load(fp)
     with open(args.model_config_path, "r") as fp:
-        model_config = json.load(fp)
+        model_config = _json.load(fp)
     with open(args.learning_config_path, "r") as fp:
-        learning_config = json.load(fp)
+        learning_config = _json.load(fp)
 
     _nam_full(data_config, model_config, learning_config, outdir, args.no_show)
diff --git a/nam/data.py b/nam/data.py
index 4c72866..2c94823 100644
--- a/nam/data.py
+++ b/nam/data.py
@@ -6,35 +6,43 @@
 Functions and classes for working with audio data with NAM
 """
 
-import abc
-import logging
-from collections import namedtuple
-from copy import deepcopy
-from dataclasses import dataclass
-from enum import Enum
-from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-import wavio
-from scipy.interpolate import interp1d
+import abc as _abc
+import logging as _logging
+from collections import namedtuple as _namedtuple
+from copy import deepcopy as _deepcopy
+from dataclasses import dataclass as _dataclass
+from enum import Enum as _Enum
+from pathlib import Path as _Path
+from typing import (
+    Any as _Any,
+    Callable as _Callable,
+    Dict as _Dict,
+    Optional as _Optional,
+    Sequence as _Sequence,
+    Tuple as _Tuple,
+    Union as _Union,
+)
+
+import numpy as _np
+import torch as _torch
+import wavio as _wavio
+from scipy.interpolate import interp1d as _interp1d
 from torch.utils.data import Dataset as _Dataset
-from tqdm import tqdm
+from tqdm import tqdm as _tqdm
 
-from ._core import InitializableFromConfig
+from ._core import InitializableFromConfig as _InitializableFromConfig
 
-logger = logging.getLogger(__name__)
+logger = _logging.getLogger(__name__)
 
 _REQUIRED_CHANNELS = 1  # Mono
 
 
-class Split(Enum):
+class Split(_Enum):
     TRAIN = "train"
     VALIDATION = "validation"
 
 
-@dataclass
+@_dataclass
 class WavInfo:
     sampwidth: int
     rate: int
@@ -69,14 +77,14 @@ def shape_actual(self):
 
 
 def wav_to_np(
-    filename: Union[str, Path],
-    rate: Optional[int] = None,
-    require_match: Optional[Union[str, Path]] = None,
-    required_shape: Optional[Tuple[int, ...]] = None,
-    required_wavinfo: Optional[WavInfo] = None,
-    preroll: Optional[int] = None,
+    filename: _Union[str, _Path],
+    rate: _Optional[int] = None,
+    require_match: _Optional[_Union[str, _Path]] = None,
+    required_shape: _Optional[_Tuple[int, ...]] = None,
+    required_wavinfo: _Optional[WavInfo] = None,
+    preroll: _Optional[int] = None,
     info: bool = False,
-) -> Union[np.ndarray, Tuple[np.ndarray, WavInfo]]:
+) -> _Union[_np.ndarray, _Tuple[_np.ndarray, WavInfo]]:
     """
     :param filename: Where to load from
     :param rate: Expected sample rate. `None` allows for anything.
@@ -89,7 +97,7 @@ def wav_to_np(
     :param preroll: Drop this many samples off the front
     :param info: If `True`, also return the WAV info of this file.
     """
-    x_wav = wavio.read(str(filename))
+    x_wav = _wavio.read(str(filename))
     assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
     if rate is not None and x_wav.rate != rate:
         raise RuntimeError(
@@ -100,7 +108,7 @@ def wav_to_np(
     if require_match is not None:
         assert required_shape is None
         assert required_wavinfo is None
-        y_wav = wavio.read(str(require_match))
+        y_wav = _wavio.read(str(require_match))
         required_shape = y_wav.data.shape
         required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate)
     if required_wavinfo is not None:
@@ -124,33 +132,33 @@ def wav_to_np(
 
 def wav_to_tensor(
     *args, info: bool = False, **kwargs
-) -> Union[torch.Tensor, Tuple[torch.Tensor, WavInfo]]:
+) -> _Union[_torch.Tensor, _Tuple[_torch.Tensor, WavInfo]]:
     out = wav_to_np(*args, info=info, **kwargs)
     if info:
         arr, info = out
-        return torch.Tensor(arr), info
+        return _torch.Tensor(arr), info
     else:
         arr = out
-        return torch.Tensor(arr)
+        return _torch.Tensor(arr)
 
 
-def tensor_to_wav(x: torch.Tensor, *args, **kwargs):
+def tensor_to_wav(x: _torch.Tensor, *args, **kwargs):
     np_to_wav(x.detach().cpu().numpy(), *args, **kwargs)
 
 
 def np_to_wav(
-    x: np.ndarray,
-    filename: Union[str, Path],
+    x: _np.ndarray,
+    filename: _Union[str, _Path],
     rate: int = 48_000,
     sampwidth: int = 3,
     scale=None,
     **kwargs,
 ):
-    if wavio.__version__ <= "0.0.4" and scale is None:
+    if _wavio.__version__ <= "0.0.4" and scale is None:
         scale = "none"
-    wavio.write(
+    _wavio.write(
         str(filename),
-        (np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(np.int32),
+        (_np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(_np.int32),
         rate,
         scale=scale,
         sampwidth=sampwidth,
@@ -158,8 +166,8 @@ def np_to_wav(
     )
 
 
-class AbstractDataset(_Dataset, abc.ABC):
-    @abc.abstractmethod
+class AbstractDataset(_Dataset, _abc.ABC):
+    @_abc.abstractmethod
     def __getitem__(self, idx: int):
         """
         Get input and output audio segment for training / evaluation.
@@ -168,7 +176,7 @@ def __getitem__(self, idx: int):
         pass
 
 
-class _DelayInterpolationMethod(Enum):
+class _DelayInterpolationMethod(_Enum):
     """
     :param LINEAR: Linear interpolation
    :param CUBIC: Cubic spline interpolation
@@ -180,22 +188,22 @@ class _DelayInterpolationMethod(Enum):
 
 
 def _interpolate_delay(
-    x: torch.Tensor, delay: float, method: _DelayInterpolationMethod
-) -> np.ndarray:
+    x: _torch.Tensor, delay: float, method: _DelayInterpolationMethod
+) -> _np.ndarray:
     """
     NOTE: This breaks the gradient tape!
     """
     if delay == 0.0:
         return x
-    t_in = np.arange(len(x))
-    n_out = len(x) - int(np.ceil(np.abs(delay)))
+    t_in = _np.arange(len(x))
+    n_out = len(x) - int(_np.ceil(_np.abs(delay)))
     if delay > 0:
-        t_out = np.arange(n_out) + delay
+        t_out = _np.arange(n_out) + delay
     elif delay < 0:
-        t_out = np.arange(len(x) - n_out, len(x)) - np.abs(delay)
+        t_out = _np.arange(len(x) - n_out, len(x)) - _np.abs(delay)
 
-    return torch.Tensor(
-        interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
+    return _torch.Tensor(
+        _interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
     )
 
 
@@ -242,33 +250,33 @@ def _sample_to_time(s, rate):
     return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
 
 
-class Dataset(AbstractDataset, InitializableFromConfig):
+class Dataset(AbstractDataset, _InitializableFromConfig):
     """
     Take a pair of matched audio files and serve input + output pairs.
     """
 
     def __init__(
         self,
-        x: torch.Tensor,
-        y: torch.Tensor,
+        x: _torch.Tensor,
+        y: _torch.Tensor,
         nx: int,
-        ny: Optional[int],
-        start: Optional[int] = None,
-        stop: Optional[int] = None,
-        start_samples: Optional[int] = None,
-        stop_samples: Optional[int] = None,
-        start_seconds: Optional[Union[int, float]] = None,
-        stop_seconds: Optional[Union[int, float]] = None,
-        delay: Optional[Union[int, float]] = None,
-        delay_interpolation_method: Union[
+        ny: _Optional[int],
+        start: _Optional[int] = None,
+        stop: _Optional[int] = None,
+        start_samples: _Optional[int] = None,
+        stop_samples: _Optional[int] = None,
+        start_seconds: _Optional[_Union[int, float]] = None,
+        stop_seconds: _Optional[_Union[int, float]] = None,
+        delay: _Optional[_Union[int, float]] = None,
+        delay_interpolation_method: _Union[
             str, _DelayInterpolationMethod
         ] = _DelayInterpolationMethod.CUBIC,
         y_scale: float = 1.0,
-        x_path: Optional[Union[str, Path]] = None,
-        y_path: Optional[Union[str, Path]] = None,
+        x_path: _Optional[_Union[str, _Path]] = None,
+        y_path: _Optional[_Union[str, _Path]] = None,
         input_gain: float = 0.0,
-        sample_rate: Optional[float] = None,
-        require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
+        sample_rate: _Optional[float] = None,
+        require_input_pre_silence: _Optional[
+            float
+        ] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
     ):
         """
         :param x: The input signal. A 1D array.
@@ -347,7 +357,7 @@ def __init__(
         self._nx = nx
         self._ny = ny if ny is not None else len(x) - nx + 1
 
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
         """
         :return:
             Input (NX+NY-1,)
@@ -370,11 +380,11 @@ def ny(self) -> int:
         return self._ny
 
     @property
-    def sample_rate(self) -> Optional[float]:
+    def sample_rate(self) -> _Optional[float]:
         return self._sample_rate
 
     @property
-    def x(self) -> torch.Tensor:
+    def x(self) -> _torch.Tensor:
         """
         The input audio data
 
@@ -383,7 +393,7 @@ def x(self) -> torch.Tensor:
         return self._x
 
     @property
-    def y(self) -> torch.Tensor:
+    def y(self) -> _torch.Tensor:
         """
         The output audio data
 
@@ -411,7 +421,7 @@ def parse_config(cls, config):
         y (torch.Tensor) - loaded from y_path
         Everything else is passed on to __init__
         """
-        config = deepcopy(config)
+        config = _deepcopy(config)
         sample_rate = config.pop("sample_rate", None)
         x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
         sample_rate = x_wavinfo.rate
@@ -460,11 +470,11 @@ def parse_config(cls, config):
     @classmethod
     def _apply_delay(
         cls,
-        x: torch.Tensor,
-        y: torch.Tensor,
-        delay: Union[int, float],
+        x: _torch.Tensor,
+        y: _torch.Tensor,
+        delay: _Union[int, float],
         method: _DelayInterpolationMethod,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
         # Check for floats that could be treated like ints (simpler algorithm)
         if isinstance(delay, float) and int(delay) == delay:
             delay = int(delay)
@@ -477,8 +487,8 @@ def _apply_delay(
 
     @classmethod
     def _apply_delay_int(
-        cls, x: torch.Tensor, y: torch.Tensor, delay: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        cls, x: _torch.Tensor, y: _torch.Tensor, delay: int
+    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
         if delay > 0:
             x = x[:-delay]
             y = y[delay:]
@@ -490,12 +500,12 @@ def _apply_delay_int(
     @classmethod
     def _apply_delay_float(
         cls,
-        x: torch.Tensor,
-        y: torch.Tensor,
+        x: _torch.Tensor,
+        y: _torch.Tensor,
         delay: float,
         method: _DelayInterpolationMethod,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        n_out = len(y) - int(np.ceil(np.abs(delay)))
+    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
+        n_out = len(y) - int(_np.ceil(_np.abs(delay)))
         if delay > 0:
             x = x[:n_out]
         elif delay < 0:
@@ -506,16 +516,16 @@ def _apply_delay_float(
     @classmethod
     def _validate_start_stop(
         cls,
-        x: torch.Tensor,
-        y: torch.Tensor,
-        start: Optional[int] = None,
-        stop: Optional[int] = None,
-        start_samples: Optional[int] = None,
-        stop_samples: Optional[int] = None,
-        start_seconds: Optional[Union[int, float]] = None,
-        stop_seconds: Optional[Union[int, float]] = None,
-        sample_rate: Optional[int] = None,
-    ) -> Tuple[Optional[int], Optional[int]]:
+        x: _torch.Tensor,
+        y: _torch.Tensor,
+        start: _Optional[int] = None,
+        stop: _Optional[int] = None,
+        start_samples: _Optional[int] = None,
+        stop_samples: _Optional[int] = None,
+        start_seconds: _Optional[_Union[int, float]] = None,
+        stop_seconds: _Optional[_Union[int, float]] = None,
+        sample_rate: _Optional[int] = None,
+    ) -> _Tuple[_Optional[int], _Optional[int]]:
         """
         Parse the requested start and stop trim points.
@@ -639,7 +649,7 @@ def _validate_inputs_after_processing(self, x, y, nx, ny):
             )
         if ny is not None:
             assert ny <= len(y) - nx + 1
-        if torch.abs(y).max() >= 1.0:
+        if _torch.abs(y).max() >= 1.0:
             msg = "Output clipped."
             if self._y_path is not None:
                 msg += f"Source is {self._y_path}"
@@ -648,10 +658,10 @@ def _validate_inputs_after_processing(self, x, y, nx, ny):
     @classmethod
     def _validate_preceding_silence(
         cls,
-        x: torch.Tensor,
-        start: Optional[int],
+        x: _torch.Tensor,
+        start: _Optional[int],
         silent_seconds: float,
-        sample_rate: Optional[float],
+        sample_rate: _Optional[float],
     ):
         """
         Make sure that the input is silent before the starting index.
@@ -677,7 +687,7 @@ def _validate_preceding_silence(
         raw_check_start = start - silent_samples
         check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0)
         check_end = start
-        if not torch.all(x[check_start:check_end] == 0.0):
+        if not _torch.all(x[check_start:check_end] == 0.0):
             raise XYError(
                 f"Input provided isn't silent for at least {silent_samples} samples "
                 "before the starting index. Responses to this non-silent input may "
             )
 
 
-class ConcatDataset(AbstractDataset, InitializableFromConfig):
-    def __init__(self, datasets: Sequence[Dataset], flatten=True):
+class ConcatDataset(AbstractDataset, _InitializableFromConfig):
+    def __init__(self, datasets: _Sequence[Dataset], flatten=True):
         if flatten:
             datasets = self._flatten_datasets(datasets)
         self._validate_datasets(datasets)
         self._datasets = datasets
         self._lookup = self._make_lookup()
 
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
         i, j = self._lookup[idx]
         return self.datasets[i][j]
@@ -712,7 +722,7 @@ def parse_config(cls, config):
         init = _dataset_init_registry[config.get("type", "dataset")]
         return {
             "datasets": tuple(
-                init(c) for c in tqdm(config["dataset_configs"], desc="Loading data")
+                init(c) for c in _tqdm(config["dataset_configs"], desc="Loading data")
             )
         }
@@ -756,8 +766,8 @@ def _make_lookup(self):
         return lookup
 
     @classmethod
-    def _validate_datasets(cls, datasets: Sequence[Dataset]):
-        Reference = namedtuple("Reference", ("index", "val"))
+    def _validate_datasets(cls, datasets: _Sequence[Dataset]):
+        Reference = _namedtuple("Reference", ("index", "val"))
         ref_keys, ref_ny = None, None
         for i, d in enumerate(datasets):
             ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
@@ -771,7 +781,7 @@ def _validate_datasets(cls, datasets: Sequence[Dataset]):
 
 
 def register_dataset_initializer(
-    name: str, constructor: Callable[[Any], AbstractDataset], overwrite=False
+    name: str, constructor: _Callable[[_Any], AbstractDataset], overwrite=False
 ):
     """
     If you have other data set types, you can register their initializer by name using
diff --git a/nam/models/_activations.py b/nam/models/_activations.py
index b6e323a..3912b0e 100644
--- a/nam/models/_activations.py
+++ b/nam/models/_activations.py
@@ -2,8 +2,8 @@
 # Created Date: Friday July 29th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-import torch.nn as nn
+import torch.nn as _nn
 
 
-def get_activation(name: str) -> nn.Module:
-    return getattr(nn, name)()
+def get_activation(name: str) -> _nn.Module:
+    return getattr(_nn, name)()
diff --git a/nam/models/base.py b/nam/models/base.py
index 8748f50..09e65b1 100644
--- a/nam/models/base.py
+++ b/nam/models/base.py
@@ -7,57 +7,64 @@
 steps)
 """
 
-import abc
-import math
-import pkg_resources
-from typing import Any, Dict, Optional, Tuple, Union
+import abc as _abc
+import math as _math
+import pkg_resources as _pkg_resources
+from typing import (
+    Any as _Any,
+    Dict as _Dict,
+    Optional as _Optional,
+    Tuple as _Tuple,
+    Union as _Union,
+)
 
-import numpy as np
-import torch
-import torch.nn as nn
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
 
-from .._core import InitializableFromConfig
-from ..data import wav_to_tensor
-from .exportable import Exportable
+from .._core import InitializableFromConfig as _InitializableFromConfig
+from ..data import wav_to_tensor as _wav_to_tensor
+from .exportable import Exportable as _Exportable
 
 
-class _Base(nn.Module, InitializableFromConfig, Exportable):
-    def __init__(self, sample_rate: Optional[float] = None):
+class _Base(_nn.Module, _InitializableFromConfig, _Exportable):
+    def __init__(self, sample_rate: _Optional[float] = None):
         super().__init__()
         self.register_buffer(
-            "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
+            "_has_sample_rate",
+            _torch.tensor(sample_rate is not None, dtype=_torch.bool),
         )
         self.register_buffer(
-            "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
+            "_sample_rate", _torch.tensor(0.0 if sample_rate is None else sample_rate)
         )
 
     @property
-    @abc.abstractmethod
+    @_abc.abstractmethod
     def pad_start_default(self) -> bool:
         pass
 
     @property
-    @abc.abstractmethod
+    @_abc.abstractmethod
     def receptive_field(self) -> int:
         """
         Receptive field of the model
         """
         pass
 
-    @abc.abstractmethod
-    def forward(self, *args, **kwargs) -> torch.Tensor:
+    @_abc.abstractmethod
+    def forward(self, *args, **kwargs) -> _torch.Tensor:
         pass
 
     @classmethod
-    def _metadata_loudness_x(cls) -> torch.Tensor:
-        return wav_to_tensor(
-            pkg_resources.resource_filename(
+    def _metadata_loudness_x(cls) -> _torch.Tensor:
+        return _wav_to_tensor(
+            _pkg_resources.resource_filename(
                 "nam", "models/_resources/loudness_input.wav"
             )
         )
 
     @property
-    def device(self) -> Optional[torch.device]:
+    def device(self) -> _Optional[_torch.device]:
         """
         Helpful property, where the parameters of the model live.
         """
@@ -69,13 +76,13 @@ def device(self) -> Optional[torch.device]:
         return None
 
     @property
-    def sample_rate(self) -> Optional[float]:
+    def sample_rate(self) -> _Optional[float]:
         return self._sample_rate.item() if self._has_sample_rate else None
 
     @sample_rate.setter
-    def sample_rate(self, val: Optional[float]):
-        self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
-        self._sample_rate = torch.tensor(0.0 if val is None else val)
+    def sample_rate(self, val: _Optional[float]):
+        self._has_sample_rate = _torch.tensor(val is not None, dtype=_torch.bool)
+        self._sample_rate = _torch.tensor(0.0 if val is None else val)
 
     def _get_export_dict(self):
         d = super()._get_export_dict()
@@ -97,17 +104,17 @@ def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
         """
         x = self._metadata_loudness_x().to(self.device)
         y = self._at_nominal_settings(gain * x)
-        loudness = torch.sqrt(torch.mean(torch.square(y)))
+        loudness = _torch.sqrt(_torch.mean(_torch.square(y)))
         if db:
-            loudness = 20.0 * torch.log10(loudness)
+            loudness = 20.0 * _torch.log10(loudness)
         return loudness.item()
 
     def _metadata_gain(self) -> float:
         """
         Between 0 and 1, how much gain / compression does the model seem to have?
""" - x = np.linspace(0.0, 1.0, 11) - y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x]) + x = _np.linspace(0.0, 1.0, 11) + y = _np.array([self._metadata_loudness(gain=gain, db=False) for gain in x]) # # O ^ o o o o o o # u | o x +-------------------------------------+ @@ -123,14 +130,14 @@ def _metadata_gain(self) -> float: gain_range = max_gain - min_gain this_gain = y.sum() normalized_gain = (this_gain - min_gain) / gain_range - return np.clip(normalized_gain, 0.0, 1.0) + return _np.clip(normalized_gain, 0.0, 1.0) - def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: + def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor: # parametric?... raise NotImplementedError() - @abc.abstractmethod - def _forward(self, *args) -> torch.Tensor: + @_abc.abstractmethod + def _forward(self, *args) -> _torch.Tensor: """ The true forward method. @@ -139,27 +146,27 @@ def _forward(self, *args) -> torch.Tensor: """ pass - def _export_input_output_args(self) -> Tuple[Any]: + def _export_input_output_args(self) -> _Tuple[_Any]: """ Create any other args necessesary (e.g. params to eval at) """ return () - def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: + def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]: args = self._export_input_output_args() rate = self.sample_rate if rate is None: raise RuntimeError( "Cannot export model's input and output without a sample rate." ) - x = torch.cat( + x = _torch.cat( [ - torch.zeros((rate,)), + _torch.zeros((rate,)), 0.5 - * torch.sin( - 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1] + * _torch.sin( + 2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1] ), - torch.zeros((rate,)), + _torch.zeros((rate,)), ] ) # Use pad start to ensure same length as requested by ._export_input_output() @@ -174,14 +181,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._mps_65536_fallback = False - def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs): + def forward(self, x: _torch.Tensor, pad_start: _Optional[bool] = None, **kwargs): pad_start = self.pad_start_default if pad_start is None else pad_start scalar = x.ndim == 1 if scalar: x = x[None] if pad_start: - x = torch.cat( - (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 + x = _torch.cat( + (_torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), + dim=1, ) if x.shape[1] < self.receptive_field: raise ValueError( @@ -193,10 +201,10 @@ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs): y = y[0] return y - def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: + def _at_nominal_settings(self, x: _torch.Tensor) -> _torch.Tensor: return self(x) - def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + def _forward_mps_safe(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor: """ Wrap `._forward()` to protect against MPS-unsupported input lengths beyond 65,536 samples. 
@@ -213,7 +221,7 @@ def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
                     "===WARNING===\n"
                     "NAM encountered a bug in PyTorch's MPS backend and will "
                     "switch to a fallback.\n"
-                    f"Your version of PyTorch is {torch.__version__}.\n"
+                    f"Your version of PyTorch is {_torch.__version__}.\n"
                     "Please report this in an Issue at:\n"
                     "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
                     "\n"
@@ -236,10 +244,10 @@ def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
                 # Bit hacky, but correct.
                 if j == x.shape[1]:
                     break
-            return torch.cat(out_list, dim=1)
+            return _torch.cat(out_list, dim=1)
 
-    @abc.abstractmethod
-    def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+    @_abc.abstractmethod
+    def _forward(self, x: _torch.Tensor, **kwargs) -> _torch.Tensor:
         """
         The true forward method.
 
@@ -248,7 +256,7 @@ def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         pass
 
-    def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+    def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
         d = super()._get_non_user_metadata()
         d["loudness"] = self._metadata_loudness()
         d["gain"] = self._metadata_gain()
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
index 046af6f..71aafd8 100644
--- a/nam/models/conv_net.py
+++ b/nam/models/conv_net.py
@@ -2,28 +2,37 @@
 # Created Date: Saturday February 5th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-import json
-import math
-from enum import Enum
-from functools import partial
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import json as _json
+import math as _math
+from enum import Enum as _Enum
+from functools import partial as _partial
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+from typing import (
+    Optional as _Optional,
+    Sequence as _Sequence,
+    Tuple as _Tuple,
+    Union as _Union,
+)
+
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
+import torch.nn.functional as _F
 
 from .. import __version__
-from ..data import wav_to_tensor
-from ._activations import get_activation
-from .base import BaseNet
-from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
+from ..data import wav_to_tensor as _wav_to_tensor
+from ._activations import get_activation as _get_activation
+from .base import BaseNet as _BaseNet
+from ._names import (
+    ACTIVATION_NAME as _ACTIVATION_NAME,
+    BATCHNORM_NAME as _BATCHNORM_NAME,
+    CONV_NAME as _CONV_NAME,
+)
 
 
-class TrainStrategy(Enum):
+class TrainStrategy(_Enum):
     STRIDE = "stride"
     DILATE = "dilate"
 
@@ -31,7 +40,7 @@ class TrainStrategy(Enum):
 default_train_strategy = TrainStrategy.DILATE
 
 
-class _Functional(nn.Module):
+class _Functional(_nn.Module):
     """
     Define a layer by a function w/ no params
     """
@@ -44,37 +53,37 @@ def forward(self, *args, **kwargs):
         return self._op(*args, **kwargs)
 
 
-class _IR(nn.Module):
-    def __init__(self, filename: Union[str, Path]):
+class _IR(_nn.Module):
+    def __init__(self, filename: _Union[str, _Path]):
         super().__init__()
-        self.register_buffer("_weight", reversed(wav_to_tensor(filename))[None, None])
+        self.register_buffer("_weight", reversed(_wav_to_tensor(filename))[None, None])
 
     @property
     def length(self) -> int:
         return self._weight.shape[-1]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: _torch.Tensor) -> _torch.Tensor:
         """
         :param x: (N,D)
         :return: (N,D-length+1)
         """
-        return F.conv1d(x[:, None], self._weight)[:, 0]
+        return _F.conv1d(x[:, None], self._weight)[:, 0]
 
 
 def _conv_net(
     channels: int = 32,
-    dilations: Sequence[int] = None,
+    dilations: _Sequence[int] = None,
     batchnorm: bool = False,
     activation: str = "Tanh",
-) -> nn.Sequential:
+) -> _nn.Sequential:
     def block(cin, cout, dilation):
-        net = nn.Sequential()
+        net = _nn.Sequential()
         net.add_module(
-            CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
+            _CONV_NAME, _nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
         )
         if batchnorm:
-            net.add_module(BATCHNORM_NAME, nn.BatchNorm1d(cout))
-        net.add_module(ACTIVATION_NAME, get_activation(activation))
+            net.add_module(_BATCHNORM_NAME, _nn.BatchNorm1d(cout))
+        net.add_module(_ACTIVATION_NAME, _get_activation(activation))
         return net
 
     def check_and_expand(n, x):
@@ -86,19 +95,19 @@ def check_and_expand(n, x):
 
     dilations = [1, 2, 4, 8] if dilations is None else dilations
     receptive_field = sum(dilations) + 1
-    net = nn.Sequential()
-    net.add_module("expand", _Functional(partial(check_and_expand, receptive_field)))
+    net = _nn.Sequential()
+    net.add_module("expand", _Functional(_partial(check_and_expand, receptive_field)))
     cin = 1
     cout = channels
     for i, dilation in enumerate(dilations):
         net.add_module(f"block_{i}", block(cin, cout, dilation))
         cin = cout
-    net.add_module("head", nn.Conv1d(channels, 1, 1))
-    net.add_module("flatten", nn.Flatten())
+    net.add_module("head", _nn.Conv1d(channels, 1, 1))
+    net.add_module("flatten", _nn.Flatten())
     return net
 
 
-class ConvNet(BaseNet):
+class ConvNet(_BaseNet):
     """
     A straightforward convolutional neural network.
@@ -109,8 +118,8 @@ def __init__(
         self,
         *args,
         train_strategy: TrainStrategy = default_train_strategy,
-        ir: Optional[_IR] = None,
-        sample_rate: Optional[float] = None,
+        ir: _Optional[_IR] = None,
+        sample_rate: _Optional[float] = None,
         **kwargs,
     ):
         super().__init__(sample_rate=sample_rate)
@@ -149,12 +158,12 @@ def receptive_field(self) -> int:
     @property
     def _activation(self):
         return (
-            self._net._modules["block_0"]._modules[ACTIVATION_NAME].__class__.__name__
+            self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__
         )
 
     @property
     def _channels(self) -> int:
-        return self._net._modules["block_0"]._modules[CONV_NAME].weight.shape[0]
+        return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0]
 
     @property
     def _num_layers(self) -> int:
@@ -162,14 +171,14 @@ def _num_layers(self) -> int:
 
     @property
     def _batchnorm(self) -> bool:
-        return BATCHNORM_NAME in self._net._modules["block_0"]._modules
-
-    def export_cpp_header(self, filename: Path):
-        with TemporaryDirectory() as tmpdir:
-            tmpdir = Path(tmpdir)
-            self.export(Path(tmpdir))
-            with open(Path(tmpdir, "config.json"), "r") as fp:
-                _c = json.load(fp)
+        return _BATCHNORM_NAME in self._net._modules["block_0"]._modules
+
+    def export_cpp_header(self, filename: _Path):
+        with _TemporaryDirectory() as tmpdir:
+            tmpdir = _Path(tmpdir)
+            self.export(_Path(tmpdir))
+            with open(_Path(tmpdir, "config.json"), "r") as fp:
+                _c = _json.load(fp)
             version = _c["version"]
             config = _c["config"]
             with open(filename, "w") as f:
@@ -187,7 +196,10 @@ def export_cpp_header(self, filename: Path):
                         f"const std::string ACTIVATION = \"{config['activation']}\";\n",
                         "std::vector<float> PARAMS{"
                         + ",".join(
-                            [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))]
+                            [
+                                f"{w:.16f}"
+                                for w in _np.load(_Path(tmpdir, "weights.npy"))
+                            ]
                         )
                         + "};\n",
                     )
@@ -201,11 +213,11 @@ def _export_config(self):
             "activation": self._activation,
         }
 
-    def _export_input_output(self, x=None) -> Tuple[np.ndarray, np.ndarray]:
+    def _export_input_output(self, x=None) -> _Tuple[_np.ndarray, _np.ndarray]:
         """
         :return: (L,), (L,)
         """
-        with torch.no_grad():
+        with _torch.no_grad():
             training = self.training
             self.eval()
             x = self._export_input_signal() if x is None else x
@@ -222,18 +234,18 @@ def _export_input_signal(self):
             raise RuntimeError(
                 "Cannot export model's input and output without a sample rate."
             )
-        return torch.cat(
+        return _torch.cat(
             [
-                torch.zeros((rate,)),
+                _torch.zeros((rate,)),
                 0.5
-                * torch.sin(
-                    2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
+                * _torch.sin(
+                    2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
                 ),
-                torch.zeros((rate,)),
+                _torch.zeros((rate,)),
             ]
         )
 
-    def _export_weights(self) -> np.ndarray:
+    def _export_weights(self) -> _np.ndarray:
         """
         weights are serialized to weights.npy in the following order:
         * (expand: no params)
@@ -256,21 +268,21 @@ def _export_weights(self) -> np.ndarray:
         for i in range(self._num_layers):
             block_name = f"block_{i}"
             block = self._net._modules[block_name]
-            conv = block._modules[CONV_NAME]
+            conv = block._modules[_CONV_NAME]
             params.append(conv.weight.flatten())
             if conv.bias is not None:
                 params.append(conv.bias.flatten())
             if self._batchnorm:
-                bn = block._modules[BATCHNORM_NAME]
+                bn = block._modules[_BATCHNORM_NAME]
                 params.append(bn.running_mean.flatten())
                 params.append(bn.running_var.flatten())
                 params.append(bn.weight.flatten())
                 params.append(bn.bias.flatten())
-                params.append(torch.Tensor([bn.eps]).to(bn.weight.device))
+                params.append(_torch.Tensor([bn.eps]).to(bn.weight.device))
         head = self._net._modules["head"]
         params.append(head.weight.flatten())
         params.append(head.bias.flatten())
-        params = torch.cat(params).detach().cpu().numpy()
+        params = _torch.cat(params).detach().cpu().numpy()
         return params
 
     def _forward(self, x):
@@ -279,13 +291,13 @@ def _forward(self, x):
             y = self._ir(y)
         return y
 
-    def _get_dilations(self) -> Tuple[int]:
+    def _get_dilations(self) -> _Tuple[int]:
         return tuple(
-            self._net._modules[f"block_{i}"]._modules[CONV_NAME].dilation[0]
+            self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
             for i in range(self._num_blocks)
         )
 
-    def _get_num_blocks(self, net: nn.Sequential):
+    def _get_num_blocks(self, net: _nn.Sequential):
         i = 0
         while True:
             if f"block_{i}" not in net._modules:
diff --git a/nam/models/exportable.py b/nam/models/exportable.py
index 6bef7c5..e75b530 100644
--- a/nam/models/exportable.py
+++ b/nam/models/exportable.py
@@ -2,32 +2,39 @@
 # Created Date: Tuesday February 8th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-import abc
-import json
-import logging
-from datetime import datetime
-from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-from .metadata import Date, UserMetadata
-
-logger = logging.getLogger(__name__)
+import abc as _abc
+import json as _json
+import logging as _logging
+from datetime import datetime as _datetime
+from enum import Enum as _Enum
+from pathlib import Path as _Path
+from typing import (
+    Any as _Any,
+    Dict as _Dict,
+    Optional as _Optional,
+    Sequence as _Sequence,
+    Tuple as _Tuple,
+    Union as _Union,
+)
+
+import numpy as _np
+
+from .metadata import Date as _Date, UserMetadata as _UserMetadata
+
+logger = _logging.getLogger(__name__)
 
 # Model version is independent from package version as of package version 0.5.2 so that
 # the API of the package can iterate at a different pace from that of the model files.
 _MODEL_VERSION = "0.5.4"
 
 
-def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
+def _cast_enums(d: _Dict[_Any, _Any]) -> _Dict[_Any, _Any]:
     """
     Casts enum-type keys to their values
     """
     out = {}
     for key, val in d.items():
-        if isinstance(val, Enum):
+        if isinstance(val, _Enum):
             val = val.value
         if isinstance(val, dict):
             val = _cast_enums(val)
@@ -35,7 +42,7 @@ def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
     return out
 
 
-class Exportable(abc.ABC):
+class Exportable(_abc.ABC):
     """
     Interface for my custon export format for use in the plugin.
     """
@@ -44,11 +51,11 @@ class Exportable(abc.ABC):
 
     def export(
         self,
-        outdir: Path,
+        outdir: _Path,
         include_snapshot: bool = False,
         basename: str = "model",
-        user_metadata: Optional[UserMetadata] = None,
-        other_metadata: Optional[dict] = None,
+        user_metadata: _Optional[_UserMetadata] = None,
+        other_metadata: _Optional[dict] = None,
     ):
         """
         Interface for exporting.
@@ -81,29 +88,29 @@ def export(
         training = self.training
         self.eval()
 
-        with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
-            json.dump(model_dict, fp)
+        with open(_Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
+            _json.dump(model_dict, fp)
         if include_snapshot:
             x, y = self._export_input_output()
-            x_path = Path(outdir, "test_inputs.npy")
-            y_path = Path(outdir, "test_outputs.npy")
+            x_path = _Path(outdir, "test_inputs.npy")
+            y_path = _Path(outdir, "test_outputs.npy")
             logger.debug(f"Saving snapshot input to {x_path}")
-            np.save(x_path, x)
+            _np.save(x_path, x)
             logger.debug(f"Saving snapshot output to {y_path}")
-            np.save(y_path, y)
+            _np.save(y_path, y)
 
         # And resume training state
         self.train(training)
 
-    @abc.abstractmethod
-    def export_cpp_header(self, filename: Path):
+    @_abc.abstractmethod
+    def export_cpp_header(self, filename: _Path):
         """
         Export a .h file to compile into the plugin with the weights written right
         out as text
         """
         pass
 
-    def export_onnx(self, filename: Path):
+    def export_onnx(self, filename: _Path):
         """
         Export model in format for ONNX Runtime
         """
@@ -112,7 +119,7 @@ def export_onnx(self, filename: Path):
             f"{self.__class__.__name__}"
         )
 
-    def import_weights(self, weights: Sequence[float]):
+    def import_weights(self, weights: _Sequence[float]):
         """
         Inverse of `._export_weights()
         """
@@ -121,7 +128,7 @@ def import_weights(self, weights: Sequence[float]):
             "implemented yet."
         )
 
-    @abc.abstractmethod
+    @_abc.abstractmethod
     def _export_config(self):
         """
         Creates the JSON of the model's archtecture hyperparameters (number of layers,
@@ -131,8 +138,8 @@ def _export_config(self):
         """
         pass
 
-    @abc.abstractmethod
-    def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+    @_abc.abstractmethod
+    def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
         """
         Create an input and corresponding output signal to verify its behavior.
@@ -141,8 +148,8 @@ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
         """
         pass
 
-    @abc.abstractmethod
-    def _export_weights(self) -> np.ndarray:
+    @_abc.abstractmethod
+    def _export_weights(self) -> _np.ndarray:
         """
         Flatten the weights out to a 1D array
         """
@@ -157,13 +164,13 @@ def _get_export_dict(self):
             "weights": self._export_weights().tolist(),
         }
 
-    def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+    def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
         """
         Get any metadata that's non-user-provided (date, loudness, gain)
         """
-        t = datetime.now()
+        t = _datetime.now()
         return {
-            "date": Date(
+            "date": _Date(
                 year=t.year,
                 month=t.month,
                 day=t.day,
diff --git a/nam/models/linear.py b/nam/models/linear.py
index ff7565a..b0dd911 100644
--- a/nam/models/linear.py
+++ b/nam/models/linear.py
@@ -6,18 +6,18 @@
 Linear model
 """
 
-import numpy as np
-import torch
-import torch.nn as nn
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
 
 from .._version import __version__
-from .base import BaseNet
+from .base import BaseNet as _BaseNet
 
 
-class Linear(BaseNet):
+class Linear(_BaseNet):
     def __init__(self, receptive_field: int, *args, bias: bool = False, **kwargs):
         super().__init__(*args, **kwargs)
-        self._net = nn.Conv1d(1, 1, receptive_field, bias=bias)
+        self._net = _nn.Conv1d(1, 1, receptive_field, bias=bias)
 
     @property
     def pad_start_default(self) -> bool:
@@ -34,7 +34,7 @@ def export_cpp_header(self):
     def _bias(self) -> bool:
         return self._net.bias is not None
 
-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward(self, x: _torch.Tensor) -> _torch.Tensor:
         return self._net(x[:, None])[:, 0]
 
     def _export_config(self):
@@ -43,9 +43,9 @@ def _export_config(self):
             "bias": self._bias,
         }
 
-    def _export_weights(self) -> np.ndarray:
+    def _export_weights(self) -> _np.ndarray:
         params_list = [self._net.weight.flatten()]
         if self._bias:
             params_list.append(self._net.bias.flatten())
-        params = torch.cat(params_list).detach().cpu().numpy()
+        params = _torch.cat(params_list).detach().cpu().numpy()
         return params
diff --git a/nam/models/losses.py b/nam/models/losses.py
index 31a2281..d49bc64 100644
--- a/nam/models/losses.py
+++ b/nam/models/losses.py
@@ -6,13 +6,13 @@
 Loss functions
 """
 
-from typing import Optional
+from typing import Optional as _Optional
 
-import torch
-from auraloss.freq import MultiResolutionSTFTLoss
+import torch as _torch
+from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss
 
 
-def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
+def apply_pre_emphasis_filter(x: _torch.Tensor, coef: float) -> _torch.Tensor:
     """
     Apply first-order pre-emphsis filter
 
@@ -24,7 +24,7 @@ def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
     return x[..., 1:] - coef * x[..., :-1]
 
 
-def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+def esr(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
     """
     ESR of (a batch of) predictions & targets
 
@@ -42,18 +42,18 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         raise ValueError(
             f"Expect 2D targets (batch_size, num_samples). Got {targets.shape}"
         )
-    return torch.mean(
-        torch.mean(torch.square(preds - targets), dim=1)
-        / torch.mean(torch.square(targets), dim=1)
+    return _torch.mean(
+        _torch.mean(_torch.square(preds - targets), dim=1)
+        / _torch.mean(_torch.square(targets), dim=1)
     )
 
 
 def multi_resolution_stft_loss(
-    preds: torch.Tensor,
-    targets: torch.Tensor,
-    loss_func: Optional[MultiResolutionSTFTLoss] = None,
-    device: Optional[torch.device] = None,
-) -> torch.Tensor:
+    preds: _torch.Tensor,
+    targets: _torch.Tensor,
+    loss_func: _Optional[_MultiResolutionSTFTLoss] = None,
+    device: _Optional[_torch.device] = None,
+) -> _torch.Tensor:
     """
     Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
     B: Batch size
@@ -66,13 +66,13 @@ def multi_resolution_stft_loss(
     :param device: If provided, send the preds and targets to the provided device.
    :return: ()
     """
-    loss_func = MultiResolutionSTFTLoss() if loss_func is None else loss_func
+    loss_func = _MultiResolutionSTFTLoss() if loss_func is None else loss_func
     if device is not None:
         preds, targets = [z.to(device) for z in (preds, targets)]
     return loss_func(preds, targets)
 
 
-def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+def mse_fft(preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
     """
     Fourier loss
 
@@ -80,7 +80,7 @@ def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     :param targets: Same as preds
     :return: ()
     """
-    fp = torch.fft.fft(preds)
-    ft = torch.fft.fft(targets)
+    fp = _torch.fft.fft(preds)
+    ft = _torch.fft.fft(targets)
     e = fp - ft
-    return torch.mean(torch.square(e.abs()))
+    return _torch.mean(_torch.square(e.abs()))
diff --git a/nam/models/metadata.py b/nam/models/metadata.py
index 9192035..d481052 100644
--- a/nam/models/metadata.py
+++ b/nam/models/metadata.py
@@ -6,14 +6,14 @@
 Metadata about models
 """
 
-from enum import Enum
-from typing import Optional
+from enum import Enum as _Enum
+from typing import Optional as _Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel as _BaseModel
 
 
 # Note: if you change this enum, you need to update the options in easy_colab.ipynb!
-class GearType(Enum):
+class GearType(_Enum):
     AMP = "amp"
     PEDAL = "pedal"
     PEDAL_AMP = "pedal_amp"
@@ -24,7 +24,7 @@ class GearType(Enum):
 
 
 # Note: if you change this enum, you need to update the options in easy_colab.ipynb!
-class ToneType(Enum):
+class ToneType(_Enum):
     CLEAN = "clean"
     OVERDRIVE = "overdrive"
     CRUNCH = "crunch"
@@ -32,7 +32,7 @@ class ToneType(Enum):
     FUZZ = "fuzz"
 
 
-class Date(BaseModel):
+class Date(_BaseModel):
     year: int
     month: int
     day: int
@@ -41,7 +41,7 @@ class Date(BaseModel):
     second: int
 
 
-class UserMetadata(BaseModel):
+class UserMetadata(_BaseModel):
     """
     Metadata that users provide for a NAM model
 
@@ -57,11 +57,11 @@ class UserMetadata(BaseModel):
     the model.
""" - name: Optional[str] = None - modeled_by: Optional[str] = None - gear_type: Optional[GearType] = None - gear_make: Optional[str] = None - gear_model: Optional[str] = None - tone_type: Optional[ToneType] = None - input_level_dbu: Optional[float] = None - output_level_dbu: Optional[float] = None + name: _Optional[str] = None + modeled_by: _Optional[str] = None + gear_type: _Optional[GearType] = None + gear_make: _Optional[str] = None + gear_model: _Optional[str] = None + tone_type: _Optional[ToneType] = None + input_level_dbu: _Optional[float] = None + output_level_dbu: _Optional[float] = None diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py index 991165f..96dde8d 100644 --- a/nam/models/recurrent.py +++ b/nam/models/recurrent.py @@ -8,20 +8,20 @@ TODO batch_first=False (I get it...) """ -import abc -import json -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Tuple +import abc as _abc +import json as _json +from pathlib import Path as _Path +from tempfile import TemporaryDirectory as _TemporaryDirectory +from typing import Optional as Optional, Tuple as _Tuple -import numpy as np -import torch -import torch.nn as nn +import numpy as _np +import torch as _torch +import torch.nn as _nn -from .base import BaseNet +from .base import BaseNet as _BaseNet -class _L(nn.LSTM): +class _L(_nn.LSTM): """ Tweaks to PyTorch LSTM module * Up the remembering @@ -47,24 +47,24 @@ def reset_parameters(self) -> None: # DH: Hidden state dimension # [0]: hidden (L,DH) # [1]: cell (L,DH) -_LSTMHiddenType = torch.Tensor -_LSTMCellType = torch.Tensor -_LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType] +_LSTMHiddenType = _torch.Tensor +_LSTMCellType = _torch.Tensor +_LSTMHiddenCellType = _Tuple[_LSTMHiddenType, _LSTMCellType] # TODO get this somewhere more core-ish -class _ExportsWeights(abc.ABC): - @abc.abstractmethod - def export_weights(self) -> np.ndarray: +class _ExportsWeights(_abc.ABC): + @_abc.abstractmethod + def export_weights(self) -> _np.ndarray: """ :return: a 1D array of weights """ pass -class _Linear(nn.Linear, _ExportsWeights): +class _Linear(_nn.Linear, _ExportsWeights): def export_weights(self): - return np.concatenate( + return _np.concatenate( [ self.weight.data.detach().cpu().numpy().flatten(), self.bias.data.detach().cpu().numpy().flatten(), @@ -72,7 +72,7 @@ def export_weights(self): ) -class LSTM(BaseNet): +class LSTM(_BaseNet): """ ABC for recurrent architectures """ @@ -105,16 +105,16 @@ def __init__( self._head = self._init_head(hidden_size) self._train_burn_in = train_burn_in self._train_truncate = train_truncate - self._initial_cell = nn.Parameter( - torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) + self._initial_cell = _nn.Parameter( + _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) ) - self._initial_hidden = nn.Parameter( - torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) + self._initial_hidden = _nn.Parameter( + _torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) ) self._get_initial_state_burn_in = 48_000 @property - def input_device(self) -> torch.device: + def input_device(self) -> _torch.device: """ What device does the input need to be on? """ @@ -129,12 +129,12 @@ def pad_start_default(self) -> bool: # I should simplify this... 
         return True
 
-    def export_cpp_header(self, filename: Path):
-        with TemporaryDirectory() as tmpdir:
-            tmpdir = Path(tmpdir)
-            LSTM.export(self, Path(tmpdir))  # Hacky...need to work w/ CatLSTM
-            with open(Path(tmpdir, "model.nam"), "r") as fp:
-                _c = json.load(fp)
+    def export_cpp_header(self, filename: _Path):
+        with _TemporaryDirectory() as tmpdir:
+            tmpdir = _Path(tmpdir)
+            LSTM.export(self, _Path(tmpdir))  # Hacky...need to work w/ CatLSTM
+            with open(_Path(tmpdir, "model.nam"), "r") as fp:
+                _c = _json.load(fp)
             version = _c["version"]
             config = _c["config"]
             s_parametric = self._export_cpp_header_parametric(config.get("parametric"))
@@ -159,7 +159,7 @@ def export_cpp_header(self, filename: Path):
             )
         )
 
-    def _apply_head(self, features: torch.Tensor) -> torch.Tensor:
+    def _apply_head(self, features: _torch.Tensor) -> _torch.Tensor:
         """
         :param features: (B,S,DH)
         :return: (B,S)
@@ -167,8 +167,8 @@ def _apply_head(self, features: torch.Tensor) -> torch.Tensor:
         return self._head(features)[:, :, 0]
 
     def _forward(
-        self, x: torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
-    ) -> torch.Tensor:
+        self, x: _torch.Tensor, initial_state: Optional[_LSTMHiddenCellType] = None
+    ) -> _torch.Tensor:
         """
         :param x: (B,L) or (B,L,D)
         :return: (B,L)
@@ -183,7 +183,7 @@ def process_in_blocks(x, hidden_state=None):
                     x[:, i : i + BLOCK_SIZE, :], hidden_state
                 )
                 outputs.append(out)
-            return torch.cat(outputs, dim=1), hidden_state  # assert batch_first
+            return _torch.cat(outputs, dim=1), hidden_state  # assert batch_first
 
         last_hidden_state = (
             self._initial_state(len(x)) if initial_state is None else initial_state
@@ -208,12 +208,12 @@ def process_in_blocks(x, hidden_state=None):
                     x[:, i : i + self._train_truncate, :], last_hidden_state
                 )
                 output_features_list.append(last_output_features)
-        output_features = torch.cat(output_features_list, dim=1)
+        output_features = _torch.cat(output_features_list, dim=1)
         return self._apply_head(output_features)
 
     def _export_cell_weights(
-        self, i: int, hidden_state: torch.Tensor, cell_state: torch.Tensor
-    ) -> np.ndarray:
+        self, i: int, hidden_state: _torch.Tensor, cell_state: _torch.Tensor
+    ) -> _np.ndarray:
         """
         * weight matrix (xh -> ifco)
         * bias vector
@@ -222,7 +222,7 @@ def _export_cell_weights(
         """
 
         tensors = [
-            torch.cat(
+            _torch.cat(
                 [
                     getattr(self._core, f"weight_ih_l{i}").data,
                     getattr(self._core, f"weight_hh_l{i}").data,
@@ -234,7 +234,7 @@ def _export_cell_weights(
             hidden_state,
             cell_state,
         ]
-        return np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])
+        return _np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])
 
     def _export_config(self):
         return {
@@ -259,7 +259,7 @@ def _export_weights(self):
         * Head weights
         * Head bias
         """
-        return np.concatenate(
+        return _np.concatenate(
            [
                 self._export_cell_weights(i, h, c)
                 for i, (h, c) in enumerate(zip(*self._get_initial_state()))
@@ -279,7 +279,7 @@ def _get_initial_state(self, inputs=None) -> _LSTMHiddenCellType:
         :return: (L,DH), (L,DH)
         """
         inputs = (
-            torch.zeros((1, self._get_initial_state_burn_in, 1))
+            _torch.zeros((1, self._get_initial_state_burn_in, 1))
             if inputs is None
             else inputs
         ).to(self.input_device)
@@ -298,7 +298,7 @@ def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
             (self._initial_hidden, self._initial_cell)
             if n is None
             else (
-                torch.tile(self._initial_hidden[:, None], (1, n, 1)),
-                torch.tile(self._initial_cell[:, None], (1, n, 1)),
+                _torch.tile(self._initial_hidden[:, None], (1, n, 1)),
+                _torch.tile(self._initial_cell[:, None], (1, n, 1)),
             )
         )
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
index fe284d4..fc9f500 100644
--- a/nam/models/wavenet.py
+++ b/nam/models/wavenet.py
@@ -7,34 +7,39 @@
 https://arxiv.org/abs/1609.03499
 """
 
-import json
-from copy import deepcopy
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Sequence, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from ._activations import get_activation
-from .base import BaseNet
-from ._names import ACTIVATION_NAME, CONV_NAME
-
-
-class Conv1d(nn.Conv1d):
-    def export_weights(self) -> torch.Tensor:
+import json as _json
+from copy import deepcopy as _deepcopy
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory
+from typing import (
+    Dict as _Dict,
+    Optional as _Optional,
+    Sequence as _Sequence,
+    Tuple as _Tuple,
+)
+
+import numpy as _np
+import torch as _torch
+import torch.nn as _nn
+
+from ._activations import get_activation as _get_activation
+from .base import BaseNet as _BaseNet
+from ._names import ACTIVATION_NAME as _ACTIVATION_NAME, CONV_NAME as _CONV_NAME
+
+
+class Conv1d(_nn.Conv1d):
+    def export_weights(self) -> _torch.Tensor:
         tensors = []
         if self.weight is not None:
             tensors.append(self.weight.data.flatten())
         if self.bias is not None:
             tensors.append(self.bias.data.flatten())
         if len(tensors) == 0:
-            return torch.zeros((0,))
+            return _torch.zeros((0,))
         else:
-            return torch.cat(tensors)
+            return _torch.cat(tensors)
 
-    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
         if self.weight is not None:
             n = self.weight.numel()
             self.weight.data = (
@@ -50,7 +55,7 @@ def import_weights(self, weights: torch.Tensor, i: int) -> int:
         return i
 
 
-class _Layer(nn.Module):
+class _Layer(_nn.Module):
     def __init__(
         self,
         condition_size: int,
@@ -67,7 +72,7 @@ def __init__(
         # Custom init: favors direct input-output
         # self._conv.weight.data.zero_()
         self._input_mixer = Conv1d(condition_size, mid_channels, 1, bias=False)
-        self._activation = get_activation(activation)
+        self._activation = _get_activation(activation)
         self._activation_name = activation
         self._1x1 = Conv1d(channels, channels, 1)
         self._gated = gated
@@ -88,8 +93,8 @@ def gated(self) -> bool:
     def kernel_size(self) -> int:
         return self._conv.kernel_size[0]
 
-    def export_weights(self) -> torch.Tensor:
-        return torch.cat(
+    def export_weights(self) -> _torch.Tensor:
+        return _torch.cat(
             [
                 self.conv.export_weights(),
                 self._input_mixer.export_weights(),
@@ -98,8 +103,8 @@ def export_weights(self) -> torch.Tensor:
         )
 
     def forward(
-        self, x: torch.Tensor, h: Optional[torch.Tensor], out_length: int
-    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
+        self, x: _torch.Tensor, h: _Optional[_torch.Tensor], out_length: int
+    ) -> _Tuple[_Optional[_torch.Tensor], _torch.Tensor]:
         """
         :param x: (B,C,L1) From last layer
         :param h: (B,DX,L2) Conditioning. If first, ignored.
@@ -117,7 +122,7 @@ def forward(
             if not self._gated
             else (
                 self._activation(z1[:, : self._channels])
-                * torch.sigmoid(z1[:, self._channels :])
+                * _torch.sigmoid(z1[:, self._channels :])
             )
         )
         return (
@@ -125,7 +130,7 @@ def forward(
             post_activation[:, :, -out_length:],
         )
 
-    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
         i = self.conv.import_weights(weights, i)
         i = self._input_mixer.import_weights(weights, i)
         return self._1x1.import_weights(weights, i)
@@ -135,7 +140,7 @@ def _channels(self) -> int:
         return self._1x1.in_channels
 
 
-class _Layers(nn.Module):
+class _Layers(_nn.Module):
     """
     Takes in the input and condition (and maybe the head input so far); outputs the
     layer output and head input.
@@ -152,14 +157,14 @@ def __init__(
         head_size,
         channels: int,
         kernel_size: int,
-        dilations: Sequence[int],
+        dilations: _Sequence[int],
         activation: str = "Tanh",
         gated: bool = True,
         head_bias: bool = True,
     ):
         super().__init__()
         self._rechannel = Conv1d(input_size, channels, 1, bias=False)
-        self._layers = nn.ModuleList(
+        self._layers = _nn.ModuleList(
             [
                 _Layer(
                     condition_size, channels, kernel_size, dilation, activation, gated
@@ -187,16 +192,16 @@ def receptive_field(self) -> int:
         return 1 + (self._kernel_size - 1) * sum(self._dilations)
 
     def export_config(self):
-        return deepcopy(self._config)
+        return _deepcopy(self._config)
 
-    def export_weights(self) -> torch.Tensor:
-        return torch.cat(
+    def export_weights(self) -> _torch.Tensor:
+        return _torch.cat(
             [self._rechannel.export_weights()]
             + [layer.export_weights() for layer in self._layers]
             + [self._head_rechannel.export_weights()]
         )
 
-    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
         i = self._rechannel.import_weights(weights, i)
         for layer in self._layers:
             i = layer.import_weights(weights, i)
@@ -204,10 +209,10 @@ def import_weights(self, weights: torch.Tensor, i: int) -> int:
 
     def forward(
         self,
-        x: torch.Tensor,
-        c: torch.Tensor,
-        head_input: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        x: _torch.Tensor,
+        c: _torch.Tensor,
+        head_input: _Optional[_torch.Tensor] = None,
+    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
         """
         :param x: (B,Dx,L) layer input
         :param c: (B,Dc,L) condition
@@ -228,7 +233,7 @@ def forward(
         return self._head_rechannel(head_input), x
 
     @property
-    def _dilations(self) -> Sequence[int]:
+    def _dilations(self) -> _Sequence[int]:
         return self._config["dilations"]
 
     @property
@@ -236,7 +241,7 @@ def _kernel_size(self) -> int:
         return self._layers[0].kernel_size
 
 
-class _Head(nn.Module):
+class _Head(_nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -248,14 +253,14 @@ def __init__(
         super().__init__()
 
         def block(cx, cy):
-            net = nn.Sequential()
-            net.add_module(ACTIVATION_NAME, get_activation(activation))
-            net.add_module(CONV_NAME, Conv1d(cx, cy, 1))
+            net = _nn.Sequential()
+            net.add_module(_ACTIVATION_NAME, _get_activation(activation))
+            net.add_module(_CONV_NAME, Conv1d(cx, cy, 1))
             return net
 
         assert num_layers > 0
 
-        layers = nn.Sequential()
+        layers = _nn.Sequential()
         cin = in_channels
         for i in range(num_layers):
             layers.add_module(
@@ -273,30 +278,30 @@ def block(cx, cy):
         }
 
     def export_config(self):
-        return deepcopy(self._config)
+        return _deepcopy(self._config)
 
-    def export_weights(self) -> torch.Tensor:
-        return torch.cat([layer[1].export_weights() for layer in self._layers])
+    def export_weights(self) -> _torch.Tensor:
+        return _torch.cat([layer[1].export_weights() for layer in self._layers])
 
     def forward(self, *args, **kwargs):
         return self._layers(*args, **kwargs)
 
-    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+    def import_weights(self, weights: _torch.Tensor, i: int) -> int:
         for layer in self._layers:
             i = layer[1].import_weights(weights, i)
         return i
 
 
-class _WaveNet(nn.Module):
+class _WaveNet(_nn.Module):
     def __init__(
         self,
-        layers_configs: Sequence[Dict],
-        head_config: Optional[Dict] = None,
+        layers_configs: _Sequence[_Dict],
+        head_config: _Optional[_Dict] = None,
         head_scale: float = 1.0,
     ):
         super().__init__()
-        self._layers = nn.ModuleList([_Layers(**lc) for lc in layers_configs])
+        self._layers = _nn.ModuleList([_Layers(**lc) for lc in layers_configs])
         self._head = None if head_config is None else _Head(**head_config)
         self._head_scale = head_scale
@@ -311,22 +316,22 @@ def export_config(self):
             "head_scale": self._head_scale,
         }
 
-    def export_weights(self) -> np.ndarray:
+    def export_weights(self) -> _np.ndarray:
         """
         :return: 1D array
         """
-        weights = torch.cat([layer.export_weights() for layer in self._layers])
+        weights = _torch.cat([layer.export_weights() for layer in self._layers])
         if self._head is not None:
-            weights = torch.cat([weights, self._head.export_weights()])
-        weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])])
+            weights = _torch.cat([weights, self._head.export_weights()])
+        weights = _torch.cat([weights.cpu(), _torch.Tensor([self._head_scale])])
         return weights.detach().cpu().numpy()
 
-    def import_weights(self, weights: torch.Tensor):
+    def import_weights(self, weights: _torch.Tensor):
         i = 0
         for layer in self._layers:
             i = layer.import_weights(weights, i)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: _torch.Tensor) -> _torch.Tensor:
         """
         :param x: (B,Cx,L)
         :return: (B,Cy,L-R)
@@ -338,8 +343,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return head_input if self._head is None else self._head(head_input)
 
 
-class WaveNet(BaseNet):
-    def __init__(self, *args, sample_rate: Optional[float] = None, **kwargs):
+class WaveNet(_BaseNet):
+    def __init__(self, *args, sample_rate: _Optional[float] = None, **kwargs):
         super().__init__(sample_rate=sample_rate)
         self._net = _WaveNet(*args, **kwargs)
@@ -351,12 +356,12 @@ def pad_start_default(self) -> bool:
     def receptive_field(self) -> int:
         return self._net.receptive_field
 
-    def export_cpp_header(self, filename: Path):
-        with TemporaryDirectory() as tmpdir:
-            tmpdir = Path(tmpdir)
-            WaveNet.export(self, Path(tmpdir))  # Hacky...need to work w/ CatWaveNet
-            with open(Path(tmpdir, "model.nam"), "r") as fp:
-                _c = json.load(fp)
+    def export_cpp_header(self, filename: _Path):
+        with _TemporaryDirectory() as tmpdir:
+            tmpdir = _Path(tmpdir)
+            WaveNet.export(self, _Path(tmpdir))  # Hacky...need to work w/ CatWaveNet
+            with open(_Path(tmpdir, "model.nam"), "r") as fp:
+                _c = _json.load(fp)
             version = _c["version"]
             config = _c["config"]
@@ -412,9 +417,9 @@ def export_cpp_header(self, filename: Path):
             )
         )
 
-    def import_weights(self, weights: Sequence[float]):
-        if not isinstance(weights, torch.Tensor):
-            weights = torch.Tensor(weights)
+    def import_weights(self, weights: _Sequence[float]):
+        if not isinstance(weights, _torch.Tensor):
+            weights = _torch.Tensor(weights)
         self._net.import_weights(weights)
 
     def _export_config(self):
@@ -425,7 +430,7 @@ def _export_cpp_header_parametric(self, config):
             raise ValueError("Got non-None parametric config")
         return ("nlohmann::json PARAMETRIC {};\n",)
 
-    def _export_weights(self) -> np.ndarray:
+    def _export_weights(self) -> _np.ndarray:
         return self._net.export_weights()
 
     def _forward(self, x):
diff --git a/nam/train/_names.py b/nam/train/_names.py
index e2cd1e5..9ee9825 100644
--- a/nam/train/_names.py
+++ b/nam/train/_names.py
@@ -2,15 +2,15 @@
 # Created Date: Monday November 6th 2023
 # Author: Steven Atkinson (steven@atkinson.mn)
 
-from typing import NamedTuple, Optional, Set
+from typing import NamedTuple as _NamedTuple, Optional as _Optional, Set as _Set
 
-from ._version import PROTEUS_VERSION, Version
+from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version
 
 
-class VersionAndName(NamedTuple):
+class VersionAndName(_NamedTuple):
     version: Version
     name: str
-    other_names: Optional[Set[str]]
+    other_names: _Optional[_Set[str]]
 
 
 # From most- to the least-recently-released:
@@ -22,7 +22,7 @@ class VersionAndName(NamedTuple):
     VersionAndName(Version(2, 0, 0), "v2_0_0.wav", None),
     VersionAndName(Version(1, 1, 1), "v1_1_1.wav", None),
     VersionAndName(Version(1, 0, 0), "v1.wav", None),
-    VersionAndName(PROTEUS_VERSION, "Proteus_Capture.wav", None),
+    VersionAndName(_PROTEUS_VERSION, "Proteus_Capture.wav", None),
     # ==================================================================================
 )
diff --git a/nam/train/colab.py b/nam/train/colab.py
index c4e13a9..56a4cdd 100644
--- a/nam/train/colab.py
+++ b/nam/train/colab.py
@@ -6,14 +6,18 @@
 Hide the mess in Colab to make things look pretty for users.
 """
 
-from pathlib import Path
-from typing import Optional, Tuple
+from pathlib import Path as _Path
+from typing import Optional as _Optional, Tuple as _Tuple
 
-from ..models.metadata import UserMetadata
-from ._names import INPUT_BASENAMES, LATEST_VERSION, Version
-from ._version import PROTEUS_VERSION, Version
-from .core import TrainOutput, train
-from .metadata import TRAINING_KEY
+from ..models.metadata import UserMetadata as _UserMetadata
+from ._names import (
+    INPUT_BASENAMES as _INPUT_BASENAMES,
+    LATEST_VERSION as _LATEST_VERSION,
+    Version as _Version,
+)
+from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version
+from .core import TrainOutput as _TrainOutput, train as _train
+from .metadata import TRAINING_KEY as _TRAINING_KEY
 
 _BUGGY_INPUT_BASENAMES = {
     # 1.1.0 has the spikes at the wrong spots.
@@ -23,41 +27,41 @@
 _TRAIN_PATH = "."
 
 
-def _check_for_files() -> Tuple[Version, str]:
+def _check_for_files() -> _Tuple[_Version, str]:
     # TODO use hash logic as in GUI trainer!
     print("Checking that we have all of the required audio files...")
     for name in _BUGGY_INPUT_BASENAMES:
-        if Path(name).exists():
+        if _Path(name).exists():
             raise RuntimeError(
-                f"Detected input signal {name} that has known bugs. Please download the latest input signal, {LATEST_VERSION[1]}"
+                f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}"
             )
-    for input_version, input_basename, other_names in INPUT_BASENAMES:
-        if Path(input_basename).exists():
-            if input_version == PROTEUS_VERSION:
+    for input_version, input_basename, other_names in _INPUT_BASENAMES:
+        if _Path(input_basename).exists():
+            if input_version == _PROTEUS_VERSION:
                 print(f"Using Proteus input file...")
-            elif input_version != LATEST_VERSION.version:
+            elif input_version != _LATEST_VERSION.version:
                 print(
                     f"WARNING: Using out-of-date input file {input_basename}. "
                     "Recommend downloading and using the latest version, "
-                    f"{LATEST_VERSION.name}."
+                    f"{_LATEST_VERSION.name}."
) break if other_names is not None: for other_name in other_names: - if Path(other_name).exists(): + if _Path(other_name).exists(): raise RuntimeError( f"Found out-of-date input file {other_name}. Rename it to {input_basename} and re-run." ) else: raise FileNotFoundError( - f"Didn't find NAM's input audio file. Please upload {LATEST_VERSION.name}" + f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION.name}" ) # We found it - if not Path(_OUTPUT_BASENAME).exists(): + if not _Path(_OUTPUT_BASENAME).exists(): raise FileNotFoundError( f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}." ) - if input_version != PROTEUS_VERSION: + if input_version != _PROTEUS_VERSION: print(f"Found {input_basename}, version {input_version}") else: print(f"Found Proteus input {input_basename}.") @@ -66,7 +70,7 @@ def _check_for_files() -> Tuple[Version, str]: def _get_valid_export_directory(): def get_path(version): - return Path("exported_models", f"version_{version}") + return _Path("exported_models", f"version_{version}") version = 0 while get_path(version).exists(): @@ -76,13 +80,13 @@ def get_path(version): def run( epochs: int = 100, - delay: Optional[int] = None, + delay: _Optional[int] = None, model_type: str = "WaveNet", architecture: str = "standard", lr: float = 0.004, lr_decay: float = 0.007, - seed: Optional[int] = 0, - user_metadata: Optional[UserMetadata] = None, + seed: _Optional[int] = 0, + user_metadata: _Optional[_UserMetadata] = None, ignore_checks: bool = False, fit_mrstft: bool = True, ): @@ -101,7 +105,7 @@ def run( input_version, input_basename = _check_for_files() - train_output: TrainOutput = train( + train_output: _TrainOutput = _train( input_basename, _OUTPUT_BASENAME, _TRAIN_PATH, @@ -129,6 +133,6 @@ def run( model.net.export( model_export_outdir, user_metadata=user_metadata, - other_metadata={TRAINING_KEY: training_metadata.model_dump()}, + other_metadata={_TRAINING_KEY: training_metadata.model_dump()}, ) print(f"Model exported to {model_export_outdir}. Enjoy!") diff --git a/nam/train/core.py b/nam/train/core.py index d5254f5..c3afb08 100644 --- a/nam/train/core.py +++ b/nam/train/core.py @@ -8,31 +8,46 @@ Used by the GUI and Colab trainers. """ -import hashlib -import tkinter as tk -from copy import deepcopy -from enum import Enum -from functools import partial -from pathlib import Path -from time import time -from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union - -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import torch -from pydantic import BaseModel -from pytorch_lightning.utilities.warnings import PossibleUserWarning -from torch.utils.data import DataLoader - -from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor -from ..models.exportable import Exportable -from ..models.losses import esr -from ..models.metadata import UserMetadata -from ..util import filter_warnings -from ._version import PROTEUS_VERSION, Version -from .lightning_module import LightningModule -from . 
import metadata +import hashlib as _hashlib +import tkinter as _tk +from copy import deepcopy as _deepcopy +from enum import Enum as _Enum +from functools import partial as _partial +from pathlib import Path as _Path +from time import time as _time +from typing import ( + Dict as _Dict, + NamedTuple as _NamedTuple, + Optional as _Optional, + Sequence as _Sequence, + Tuple as _Tuple, + Union as _Union, +) + +import matplotlib.pyplot as _plt +import numpy as _np +import pytorch_lightning as _pl +import torch as _torch +from pydantic import BaseModel as _BaseModel +from pytorch_lightning.utilities.warnings import ( + PossibleUserWarning as _PossibleUserWarning, +) +from torch.utils.data import DataLoader as _DataLoader + +from ..data import ( + DataError as _DataError, + Split as _Split, + init_dataset as _init_dataset, + wav_to_np as _wav_to_np, + wav_to_tensor as _wav_to_tensor, +) +from ..models.exportable import Exportable as _Exportable +from ..models.losses import esr as _ESR +from ..models.metadata import UserMetadata as _UserMetadata +from ..util import filter_warnings as _filter_warnings +from ._version import PROTEUS_VERSION as _PROTEUS_VERSION, Version as _Version +from .lightning_module import LightningModule as _LightningModule +from . import metadata as _metadata # Training using the simplified trainers in NAM is done at 48k. STANDARD_SAMPLE_RATE = 48_000.0 @@ -40,7 +55,7 @@ _NY_DEFAULT = 8192 -class Architecture(Enum): +class Architecture(_Enum): STANDARD = "standard" LITE = "lite" FEATHER = "feather" @@ -51,17 +66,17 @@ class _InputValidationError(ValueError): pass -def _detect_input_version(input_path) -> Tuple[Version, bool]: +def _detect_input_version(input_path) -> _Tuple[_Version, bool]: """ Check to see if the input matches any of the known inputs :return: version, strong match """ - def detect_strong(input_path) -> Optional[Version]: + def detect_strong(input_path) -> _Optional[_Version]: def assign_hash(path): # Use this to create hashes for new files - md5 = hashlib.md5() + md5 = _hashlib.md5() buffer_size = 65536 with open(path, "rb") as f: while True: @@ -76,11 +91,11 @@ def assign_hash(path): print(f"Strong hash: {file_hash}") version = { - "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), - "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), - "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0), - "36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0), - "80e224bd5622fd6153ff1fd9f34cb3bd": PROTEUS_VERSION, + "4d54a958861bf720ec4637f43d44a7ef": _Version(1, 0, 0), + "7c3b6119c74465f79d96c761a0e27370": _Version(1, 1, 1), + "ede3b9d82135ce10c7ace3bb27469422": _Version(2, 0, 0), + "36cd1af62985c2fac3e654333e36431e": _Version(3, 0, 0), + "80e224bd5622fd6153ff1fd9f34cb3bd": _PROTEUS_VERSION, }.get(file_hash) if version is None: print( @@ -89,17 +104,17 @@ def assign_hash(path): ) return version - def detect_weak(input_path) -> Optional[Version]: + def detect_weak(input_path) -> _Optional[_Version]: def assign_hash(path): - Hash = Optional[str] - Hashes = Tuple[Hash, Hash] + Hash = _Optional[str] + Hashes = _Tuple[Hash, Hash] - def _hash(x: np.ndarray) -> str: - return hashlib.md5(x).hexdigest() + def _hash(x: _np.ndarray) -> str: + return _hashlib.md5(x).hexdigest() def assign_hashes_v1(path) -> Hashes: # Use this to create recognized hashes for new files - x, info = wav_to_np(path, info=True) + x, info = _wav_to_np(path, info=True) rate = info.rate if rate != _V1_DATA_INFO.rate: return None, None @@ -116,7 +131,7 @@ def assign_hashes_v1(path) -> Hashes: def 
assign_hashes_v2(path) -> Hashes:
             # Use this to create recognized hashes for new files
-            x, info = wav_to_np(path, info=True)
+            x, info = _wav_to_np(path, info=True)
             rate = info.rate
             if rate != _V2_DATA_INFO.rate:
                 return None, None
@@ -133,7 +148,7 @@ def assign_hashes_v2(path) -> Hashes:

         def assign_hashes_v3(path) -> Hashes:
             # Use this to create recognized hashes for new files
-            x, info = wav_to_np(path, info=True)
+            x, info = _wav_to_np(path, info=True)
             rate = info.rate
             if rate != _V3_DATA_INFO.rate:
                 return None, None
@@ -147,7 +162,7 @@ def assign_hashes_v3(path) -> Hashes:

         def assign_hash_v4(path) -> Hash:
             # Use this to create recognized hashes for new files
-            x, info = wav_to_np(path, info=True)
+            x, info = _wav_to_np(path, info=True)
             rate = info.rate
             if rate != _V4_DATA_INFO.rate:
                 return None
@@ -195,7 +210,7 @@ def assign_hash_v4(path) -> Hash:
             (
                 "dadb5d62f6c3973a59bf01439799809b",
                 "8458126969a3f9d8e19a53554eb1fd52",
-            ): Version(3, 0, 0)
+            ): _Version(3, 0, 0)
         }.get((start_hash_v3, end_hash_v3))
         if version is not None:
             return version
@@ -203,7 +218,7 @@ def assign_hash_v4(path) -> Hash:
             (
                 "1c4d94fbcb47e4d820bef611c1d4ae65",
                 "28694e7bf9ab3f8ae6ef86e9545d4663",
-            ): Version(2, 0, 0)
+            ): _Version(2, 0, 0)
         }.get((start_hash_v2, end_hash_v2))
         if version is not None:
             return version
@@ -211,17 +226,17 @@ def assign_hash_v4(path) -> Hash:
             (
                 "bb4e140c9299bae67560d280917eb52b",
                 "9b2468fcb6e9460a399fc5f64389d353",
-            ): Version(
+            ): _Version(
                 1, 0, 0
             ),  # FIXME!
             (
                 "9f20c6b5f7fef68dd88307625a573a14",
                 "8458126969a3f9d8e19a53554eb1fd52",
-            ): Version(1, 1, 1),
+            ): _Version(1, 1, 1),
         }.get((start_hash_v1, end_hash_v1))
         if version is not None:
             return version
-        version = {"46151c8030798081acc00a725325a07d": PROTEUS_VERSION}.get(hash_v4)
+        version = {"46151c8030798081acc00a725325a07d": _PROTEUS_VERSION}.get(hash_v4)
         return version

     version = detect_strong(input_path)
@@ -239,20 +254,20 @@ def assign_hash_v4(path) -> Hash:
     return version, strong_match


-class _DataInfo(BaseModel):
+class _DataInfo(_BaseModel):
     """
     :param major_version: Data major version
     """

     major_version: int
-    rate: Optional[float]
+    rate: _Optional[float]
     t_blips: int
     first_blips_start: int
     t_validate: int
     train_start: int
     validation_start: int
-    noise_interval: Tuple[int, int]
-    blip_locations: Sequence[Sequence[int]]
+    noise_interval: _Tuple[int, int]
+    blip_locations: _Sequence[_Sequence[int]]


 _V1_DATA_INFO = _DataInfo(
@@ -336,7 +351,7 @@ class _DataInfo(BaseModel):
 _DELAY_CALIBRATION_SAFETY_FACTOR = 1  # Might be able to make this zero...


-def _warn_lookaheads(indices: Sequence[int]) -> str:
+def _warn_lookaheads(indices: _Sequence[int]) -> str:
     return (
         f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are "
         "at the minimum value possible. This usually means that something is "
@@ -350,7 +365,7 @@ def _calibrate_latency_v_all(
     abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD,
     rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
     safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
-) -> metadata.LatencyCalibration:
+) -> _metadata.LatencyCalibration:
     """
     Calibrate the delay in the input-output pair based on blips.
     This only uses the blips in the first set of blip locations!
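
The gist of the calibration implemented below: the input file carries impulse "blips" at known sample positions, a trigger level is set just above the measured background noise, and the first output sample to cross that trigger marks the rig's response; the offset between blip and response is the latency. A minimal standalone sketch of that idea, with illustrative names and threshold values rather than this function's exact logic:

    import numpy as np

    def estimate_blip_latency(
        y: np.ndarray,             # mono output audio
        blip_location: int,        # sample index where the blip was played
        noise_interval=(0, 1000),  # a known-quiet stretch, for the noise floor
        lookahead: int = 100,
        lookback: int = 10_000,
        abs_threshold: float = 0.0003,
    ) -> int:
        # Trigger level: just above the loudest sample in the quiet stretch.
        background = np.max(np.abs(y[noise_interval[0] : noise_interval[1]]))
        trigger = background + abs_threshold
        # Scan a window around the known blip location for the response onset.
        window = y[blip_location - lookahead : blip_location + lookback]
        onset = np.where(np.abs(window) > trigger)[0]
        if len(onset) == 0:
            raise RuntimeError("No response activated the trigger.")
        # Shift by the lookahead so a response landing exactly on the blip is 0.
        return int(onset[0]) - lookahead
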
@@ -359,8 +374,8 @@ def _calibrate_latency_v_all( """ def report_any_latency_warnings( - delays: Sequence[int], - ) -> metadata.LatencyCalibrationWarnings: + delays: _Sequence[int], + ) -> _metadata.LatencyCalibrationWarnings: # Warnings associated with any single delay: # "Lookahead warning": if the delay is equal to the lookahead, then it's @@ -375,7 +390,7 @@ def report_any_latency_warnings( # If they're _really_ different, then something might be wrong. max_disagreement_threshold = 20 max_disagreement_too_high = ( - np.max(delays) - np.min(delays) >= max_disagreement_threshold + _np.max(delays) - _np.min(delays) >= max_disagreement_threshold ) if max_disagreement_too_high: print( @@ -384,7 +399,7 @@ def report_any_latency_warnings( "badly, then you might need to provide the latency manually." ) - return metadata.LatencyCalibrationWarnings( + return _metadata.LatencyCalibrationWarnings( matches_lookahead=matches_lookahead, disagreement_too_high=max_disagreement_too_high, ) @@ -393,8 +408,8 @@ def report_any_latency_warnings( lookback = 10_000 # Calibrate the level for the trigger: y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips] - background_level = np.max( - np.abs( + background_level = _np.max( + _np.abs( y[ data_info.noise_interval[0] - data_info.first_blips_start : data_info.noise_interval[1] @@ -414,8 +429,8 @@ def report_any_latency_warnings( start_looking = i_rel - lookahead stop_looking = i_rel + lookback y_scans.append(y[start_looking:stop_looking]) - y_scan_average = np.mean(np.stack(y_scans), axis=0) - triggered = np.where(np.abs(y_scan_average) > trigger_threshold)[0] + y_scan_average = _np.mean(_np.stack(y_scans), axis=0) + triggered = _np.where(_np.abs(y_scan_average) > trigger_threshold)[0] if len(triggered) == 0: msg = ( "No response activated the trigger in response to input spikes. 
" @@ -423,24 +438,24 @@ def report_any_latency_warnings( ) print(msg) print("SHARE THIS PLOT IF YOU ASK FOR HELP") - plt.figure() - plt.plot( - np.arange(-lookahead, lookback), + _plt.figure() + _plt.plot( + _np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average", ) for y_scan in y_scans: - plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2) - plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") - plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold") - plt.axhline(y=trigger_threshold, color="k", linestyle="--") - plt.xlim((-lookahead, lookback)) - plt.xlabel("Samples") - plt.ylabel("Response") - plt.legend() - plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP") - plt.show() + _plt.plot(_np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2) + _plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") + _plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold") + _plt.axhline(y=trigger_threshold, color="k", linestyle="--") + _plt.xlim((-lookahead, lookback)) + _plt.xlabel("Samples") + _plt.ylabel("Response") + _plt.legend() + _plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP") + _plt.show() raise RuntimeError(msg) else: j = triggered[0] @@ -454,7 +469,7 @@ def report_any_latency_warnings( f"After aplying safety factor of {safety_factor}, the final delay is " f"{delay_post_safety_factor}" ) - return metadata.LatencyCalibration( + return _metadata.LatencyCalibration( algorithm_version=1, delays=[delay], safety_factor=safety_factor, @@ -463,72 +478,72 @@ def report_any_latency_warnings( ) -_calibrate_latency_v1 = partial(_calibrate_latency_v_all, _V1_DATA_INFO) -_calibrate_latency_v2 = partial(_calibrate_latency_v_all, _V2_DATA_INFO) -_calibrate_latency_v3 = partial(_calibrate_latency_v_all, _V3_DATA_INFO) -_calibrate_latency_v4 = partial(_calibrate_latency_v_all, _V4_DATA_INFO) +_calibrate_latency_v1 = _partial(_calibrate_latency_v_all, _V1_DATA_INFO) +_calibrate_latency_v2 = _partial(_calibrate_latency_v_all, _V2_DATA_INFO) +_calibrate_latency_v3 = _partial(_calibrate_latency_v_all, _V3_DATA_INFO) +_calibrate_latency_v4 = _partial(_calibrate_latency_v_all, _V4_DATA_INFO) def _plot_latency_v_all( data_info: _DataInfo, latency: int, input_path: str, output_path: str, _nofail=True ): print("Plotting the latency for manual inspection...") - x = wav_to_np(input_path)[ + x = _wav_to_np(input_path)[ data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips ] - y = wav_to_np(output_path)[ + y = _wav_to_np(output_path)[ data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips ] # Only get the blips we really want. - i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] + i = _np.where(_np.abs(x) > 0.5 * _np.abs(x).max())[0] if len(i) == 0: print("Failed to find the spike in the input file.") print( "Plotting the input and output; there should be spikes at around the " "marked locations." 
) - t = np.arange( + t = _np.arange( data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips ) expected_spikes = data_info.blip_locations[0] # For v1 specifically - fig, axs = plt.subplots(len((x, y)), 1) + fig, axs = _plt.subplots(len((x, y)), 1) for ax, curve in zip(axs, (x, y)): ax.plot(t, curve) [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes] - plt.show() + _plt.show() if _nofail: raise RuntimeError("Failed to plot delay") else: - plt.figure() + _plt.figure() di = 20 # V1's got not a spike but a longer plateau; take the front of it. if data_info.major_version == 1: i = [i[0]] for e, ii in enumerate(i, 1): - plt.plot( - np.arange(-di, di), + _plt.plot( + _np.arange(-di, di), y[ii - di + latency : ii + di + latency], ".-", label=f"Output {e}", ) - plt.axvline(x=0, linestyle="--", color="k") - plt.legend() - plt.show() # This doesn't freeze the notebook + _plt.axvline(x=0, linestyle="--", color="k") + _plt.legend() + _plt.show() # This doesn't freeze the notebook -_plot_latency_v1 = partial(_plot_latency_v_all, _V1_DATA_INFO) -_plot_latency_v2 = partial(_plot_latency_v_all, _V2_DATA_INFO) -_plot_latency_v3 = partial(_plot_latency_v_all, _V3_DATA_INFO) -_plot_latency_v4 = partial(_plot_latency_v_all, _V4_DATA_INFO) +_plot_latency_v1 = _partial(_plot_latency_v_all, _V1_DATA_INFO) +_plot_latency_v2 = _partial(_plot_latency_v_all, _V2_DATA_INFO) +_plot_latency_v3 = _partial(_plot_latency_v_all, _V3_DATA_INFO) +_plot_latency_v4 = _partial(_plot_latency_v_all, _V4_DATA_INFO) def _analyze_latency( - user_latency: Optional[int], - input_version: Version, + user_latency: _Optional[int], + input_version: _Version, input_path: str, output_path: str, silent: bool = False, -) -> metadata.Latency: +) -> _metadata.Latency: """ :param is_proteus: Forget the version; d """ @@ -546,14 +561,14 @@ def _analyze_latency( ) if user_latency is not None: print(f"Delay is specified as {user_latency}") - calibration_output = calibrate(wav_to_np(output_path)) + calibration_output = calibrate(_wav_to_np(output_path)) latency = ( user_latency if user_latency is not None else calibration_output.recommended ) if not silent: plot(latency, input_path, output_path) - return metadata.Latency(manual=user_latency, calibration=calibration_output) + return _metadata.Latency(manual=user_latency, calibration=calibration_output) def get_lstm_config(architecture): @@ -585,8 +600,8 @@ def get_lstm_config(architecture): }[architecture] -def _check_v1(*args, **kwargs) -> metadata.DataChecks: - return metadata.DataChecks(version=1, passed=True) +def _check_v1(*args, **kwargs) -> _metadata.DataChecks: + return _metadata.DataChecks(version=1, passed=True) def _esr_validation_replicate_msg(threshold: float) -> str: @@ -601,16 +616,18 @@ def _esr_validation_replicate_msg(threshold: float) -> str: ) -def _check_v2(input_path, output_path, delay: int, silent: bool) -> metadata.DataChecks: - with torch.no_grad(): +def _check_v2( + input_path, output_path, delay: int, silent: bool +) -> _metadata.DataChecks: + with _torch.no_grad(): print("V2 checks...") rate = _V2_DATA_INFO.rate - y = wav_to_tensor(output_path, rate=rate) + y = _wav_to_tensor(output_path, rate=rate) t_blips = _V2_DATA_INFO.t_blips t_validate = _V2_DATA_INFO.t_validate y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)] y_val_2 = y[-(t_blips + t_validate) : -t_blips] - esr_replicate = esr(y_val_1, y_val_2).item() + esr_replicate = _ESR(y_val_1, y_val_2).item() print(f"Replicate ESR is {esr_replicate:.8f}.") 
esr_replicate_threshold = 0.01
         if esr_replicate > esr_replicate_threshold:
@@ -630,19 +647,19 @@ def get_blips(y):
             i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
             start = -10
             end = 1000
-            blips = torch.stack(
+            blips = _torch.stack(
                 [
-                    torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
-                    torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
+                    _torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
+                    _torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
                 ]
             )
             return blips

         blips = get_blips(y)
-        esr_0 = esr(blips[0][0], blips[0][1]).item()  # Within start
-        esr_1 = esr(blips[1][0], blips[1][1]).item()  # Within end
-        esr_cross_0 = esr(blips[0][0], blips[1][0]).item()  # 1st repeat, start vs end
-        esr_cross_1 = esr(blips[0][1], blips[1][1]).item()  # 2nd repeat, start vs end
+        esr_0 = _ESR(blips[0][0], blips[0][1]).item()  # Within start
+        esr_1 = _ESR(blips[1][0], blips[1][1]).item()  # Within end
+        esr_cross_0 = _ESR(blips[0][0], blips[1][0]).item()  # 1st repeat, start vs end
+        esr_cross_1 = _ESR(blips[0][1], blips[1][1]).item()  # 2nd repeat, start vs end

         print("  ESRs:")
         print(f"    Start : {esr_0}")
@@ -655,22 +672,22 @@ def get_blips(y):
     def plot_esr_blip_error(
         show_plot: bool,
         msg: str,
-        arrays: Sequence[Sequence[float]],
-        labels: Sequence[str],
+        arrays: _Sequence[_Sequence[float]],
+        labels: _Sequence[str],
     ):
         """
         :param show_plot: Whether to make and show a plot about it
         """
         if show_plot:
-            plt.figure()
-            [plt.plot(array, label=label) for array, label in zip(arrays, labels)]
-            plt.xlabel("Sample")
-            plt.ylabel("Output")
-            plt.legend()
-            plt.grid()
+            _plt.figure()
+            [_plt.plot(array, label=label) for array, label in zip(arrays, labels)]
+            _plt.xlabel("Sample")
+            _plt.ylabel("Output")
+            _plt.legend()
+            _plt.grid()
         print(msg)
         if show_plot:
-            plt.show()
+            _plt.show()
         print(
             "This is known to be a very sensitive test, so training will continue. "
            "If the model doesn't look good, then this may be why!"
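
For reference, the quantity all of these checks print, the error-to-signal ratio, is the power of the error divided by the power of the target: ESR(a, b) = mean((a - b)^2) / mean(b^2). Identical signals score 0; an error as loud as the signal scores about 1. A self-contained sketch of the quantity and the 0.01 replicate threshold used above:

    import torch

    def esr(pred: torch.Tensor, target: torch.Tensor) -> float:
        # Power of the error divided by power of the target.
        return (
            torch.mean(torch.square(pred - target))
            / torch.mean(torch.square(target))
        ).item()

    y1 = torch.sin(torch.linspace(0.0, 100.0, 48_000))
    y2 = y1 + 0.01 * torch.randn_like(y1)
    print(esr(y1, y2))  # ~2e-4: comfortably below the 0.01 replicate threshold
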
@@ -693,7 +710,7 @@ def plot_esr_blip_error( blip_pair, ("Replicate 1", "Replicate 2"), ) - return metadata.DataChecks(version=2, passed=False) + return _metadata.DataChecks(version=2, passed=False) # Check blips between start & end of train signal for e, blip_pair, replicate in zip( (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2) @@ -707,46 +724,46 @@ def plot_esr_blip_error( blip_pair, (f"Start, replicate {replicate}", f"End, replicate {replicate}"), ) - return metadata.DataChecks(version=2, passed=False) - return metadata.DataChecks(version=2, passed=True) + return _metadata.DataChecks(version=2, passed=False) + return _metadata.DataChecks(version=2, passed=True) def _check_v3( input_path, output_path, silent: bool, *args, **kwargs -) -> metadata.DataChecks: - with torch.no_grad(): +) -> _metadata.DataChecks: + with _torch.no_grad(): print("V3 checks...") rate = _V3_DATA_INFO.rate - y = wav_to_tensor(output_path, rate=rate) - n = len(wav_to_tensor(input_path)) # to End-crop output + y = _wav_to_tensor(output_path, rate=rate) + n = len(_wav_to_tensor(input_path)) # to End-crop output y_val_1 = y[: _V3_DATA_INFO.t_validate] y_val_2 = y[n - _V3_DATA_INFO.t_validate : n] - esr_replicate = esr(y_val_1, y_val_2).item() + esr_replicate = _ESR(y_val_1, y_val_2).item() print(f"Replicate ESR is {esr_replicate:.8f}.") esr_replicate_threshold = 0.01 if esr_replicate > esr_replicate_threshold: print(_esr_validation_replicate_msg(esr_replicate_threshold)) if not silent: - plt.figure() - t = np.arange(len(y_val_1)) / rate - plt.plot(t, y_val_1, label="Validation 1") - plt.plot(t, y_val_2, label="Validation 2") - plt.xlabel("Time (sec)") - plt.legend() - plt.title("V3 check: Validation replicate FAILURE") - plt.show() - return metadata.DataChecks(version=3, passed=False) - return metadata.DataChecks(version=3, passed=True) + _plt.figure() + t = _np.arange(len(y_val_1)) / rate + _plt.plot(t, y_val_1, label="Validation 1") + _plt.plot(t, y_val_2, label="Validation 2") + _plt.xlabel("Time (sec)") + _plt.legend() + _plt.title("V3 check: Validation replicate FAILURE") + _plt.show() + return _metadata.DataChecks(version=3, passed=False) + return _metadata.DataChecks(version=3, passed=True) def _check_v4( input_path, output_path, silent: bool, *args, **kwargs -) -> metadata.DataChecks: +) -> _metadata.DataChecks: # Things we can't check: # Latency compensation agreement # Data replicability print("Using Proteus audio file. Standard data checks aren't possible!") - signal, info = wav_to_np(output_path, info=True) + signal, info = _wav_to_np(output_path, info=True) passed = True if info.rate != _V4_DATA_INFO.rate: print( @@ -761,12 +778,12 @@ def _check_v4( "File doesn't meet the minimum length requirements for latency compensation and validation signal!" 
) passed = False - return metadata.DataChecks(version=4, passed=passed) + return _metadata.DataChecks(version=4, passed=passed) def _check_data( - input_path: str, output_path: str, input_version: Version, delay: int, silent: bool -) -> Optional[metadata.DataChecks]: + input_path: str, output_path: str, input_version: _Version, delay: int, silent: bool +) -> _Optional[_metadata.DataChecks]: """ Ensure that everything should go smoothly @@ -912,7 +929,11 @@ def get_wavenet_config(architecture): def _get_data_config( - input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int + input_version: _Version, + input_path: _Path, + output_path: _Path, + ny: int, + latency: int, ) -> dict: def get_split_kwargs(data_info: _DataInfo): if data_info.major_version == 1: @@ -976,7 +997,7 @@ def get_split_kwargs(data_info: _DataInfo): def _get_configs( - input_version: Version, + input_version: _Version, input_path: str, output_path: str, latency: int, @@ -1031,9 +1052,9 @@ def _get_configs( model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF - if torch.cuda.is_available(): + if _torch.cuda.is_available(): device_config = {"accelerator": "gpu", "devices": 1} - elif torch.backends.mps.is_available(): + elif _torch.backends.mps.is_available(): device_config = {"accelerator": "mps", "devices": 1} else: print("WARNING: No GPU was found. Training will be very slow!") @@ -1053,45 +1074,49 @@ def _get_configs( def _get_dataloaders( - data_config: Dict, learning_config: Dict, model: LightningModule -) -> Tuple[DataLoader, DataLoader]: - data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)] + data_config: _Dict, learning_config: _Dict, model: _LightningModule +) -> _Tuple[_DataLoader, _DataLoader]: + data_config, learning_config = [ + _deepcopy(c) for c in (data_config, learning_config) + ] data_config["common"]["nx"] = model.net.receptive_field - dataset_train = init_dataset(data_config, Split.TRAIN) - dataset_validation = init_dataset(data_config, Split.VALIDATION) - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + dataset_train = _init_dataset(data_config, _Split.TRAIN) + dataset_validation = _init_dataset(data_config, _Split.VALIDATION) + train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = _DataLoader( + dataset_validation, **learning_config["val_dataloader"] + ) return train_dataloader, val_dataloader -def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: +def _esr(pred: _torch.Tensor, target: _torch.Tensor) -> float: return ( - torch.mean(torch.square(pred - target)).item() - / torch.mean(torch.square(target)).item() + _torch.mean(_torch.square(pred - target)).item() + / _torch.mean(_torch.square(target)).item() ) def _plot( model, ds, - window_start: Optional[int] = None, - window_end: Optional[int] = None, - filepath: Optional[str] = None, + window_start: _Optional[int] = None, + window_end: _Optional[int] = None, + filepath: _Optional[str] = None, silent: bool = False, ) -> float: """ :return: The ESR """ print("Plotting a comparison of your model with the target output...") - with torch.no_grad(): + with _torch.no_grad(): tx = len(ds.x) / 48_000 print(f"Run (t={tx:.2f} sec)") - t0 = time() + t0 = _time() output = model(ds.x).flatten().cpu().numpy() - t1 = time() + t1 = 
_time() print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)") - esr = _esr(torch.Tensor(output), ds.y) + esr = _esr(_torch.Tensor(output), ds.y) # Trying my best to put numbers to it... if esr < 0.01: esr_comment = "Great!" @@ -1106,15 +1131,15 @@ def _plot( print(f"Error-signal ratio = {esr:.4g}") print(esr_comment) - plt.figure(figsize=(16, 5)) - plt.plot(output[window_start:window_end], label="Prediction") - plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") - plt.title(f"ESR={esr:.4g}") - plt.legend() + _plt.figure(figsize=(16, 5)) + _plt.plot(output[window_start:window_end], label="Prediction") + _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") + _plt.title(f"ESR={esr:.4g}") + _plt.legend() if filepath is not None: - plt.savefig(filepath + ".png") + _plt.savefig(filepath + ".png") if not silent: - plt.show() + _plt.show() return esr @@ -1139,14 +1164,14 @@ def _print_nasty_checks_warning(): def _nasty_checks_modal(): msg = "You are ignoring the checks!\nYour model might turn out bad!" - root = tk.Tk() + root = _tk.Tk() root.withdraw() # hide the root window - modal = tk.Toplevel(root) + modal = _tk.Toplevel(root) modal.geometry("300x100") modal.title("Warning!") - label = tk.Label(modal, text=msg) + label = _tk.Label(modal, text=msg) label.pack(pady=10) - ok_button = tk.Button( + ok_button = _tk.Button( modal, text="I can only blame myself!", command=lambda: [modal.destroy(), root.quit()], @@ -1156,7 +1181,7 @@ def _nasty_checks_modal(): modal.mainloop() -class _ValidationStopping(pl.callbacks.EarlyStopping): +class _ValidationStopping(_pl.callbacks.EarlyStopping): """ Callback to indicate to stop training if the validation metric is good enough, without the other conditions that EarlyStopping usually forces like patience. @@ -1164,10 +1189,10 @@ class _ValidationStopping(pl.callbacks.EarlyStopping): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.patience = np.inf + self.patience = _np.inf -class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): +class _ModelCheckpoint(_pl.callbacks.model_checkpoint.ModelCheckpoint): """ Extension to model checkpoint to save a .nam file as well as the .ckpt file. """ @@ -1175,9 +1200,9 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): def __init__( self, *args, - user_metadata: Optional[UserMetadata] = None, - settings_metadata: Optional[metadata.Settings] = None, - data_metadata: Optional[metadata.Data] = None, + user_metadata: _Optional[_UserMetadata] = None, + settings_metadata: _Optional[_metadata.Settings] = None, + data_metadata: _Optional[_metadata.Data] = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -1185,10 +1210,10 @@ def __init__( self._settings_metadata = settings_metadata self._data_metadata = data_metadata - _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION + _NAM_FILE_EXTENSION = _Exportable.FILE_EXTENSION @classmethod - def _get_nam_filepath(cls, filepath: str) -> Path: + def _get_nam_filepath(cls, filepath: str) -> _Path: """ Given a .ckpt filepath, figure out a .nam for it. 
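         E.g. (illustrative): "outdir/epoch_0099.ckpt" -> "outdir/epoch_0099.nam".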
""" @@ -1197,18 +1222,18 @@ def _get_nam_filepath(cls, filepath: str) -> Path: f"Checkpoint filepath {filepath} doesn't end in expected extension " f"{cls.FILE_EXTENSION}" ) - return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION) + return _Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION) @property def _include_other_metadata(self) -> bool: return self._settings_metadata is not None and self._data_metadata is not None - def _save_checkpoint(self, trainer: pl.Trainer, filepath: str): + def _save_checkpoint(self, trainer: _pl.Trainer, filepath: str): # Save the .ckpt: super()._save_checkpoint(trainer, filepath) # Save the .nam: nam_filepath = self._get_nam_filepath(filepath) - pl_model: LightningModule = trainer.model + pl_model: _LightningModule = trainer.model nam_model = pl_model.net outdir = nam_filepath.parent # HACK: Assume the extension @@ -1217,7 +1242,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, filepath: str): None if not self._include_other_metadata else { - metadata.TRAINING_KEY: metadata.TrainingMetadata( + _metadata.TRAINING_KEY: _metadata.TrainingMetadata( settings=self._settings_metadata, data=self._data_metadata, validation_esr=None, # TODO how to get this? @@ -1231,7 +1256,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, filepath: str): other_metadata=other_metadata, ) - def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + def _remove_checkpoint(self, trainer: _pl.Trainer, filepath: str) -> None: super()._remove_checkpoint(trainer, filepath) nam_path = self._get_nam_filepath(filepath) if nam_path.exists(): @@ -1239,10 +1264,10 @@ def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: def get_callbacks( - threshold_esr: Optional[float], - user_metadata: Optional[UserMetadata] = None, - settings_metadata: Optional[metadata.Settings] = None, - data_metadata: Optional[metadata.Data] = None, + threshold_esr: _Optional[float], + user_metadata: _Optional[_UserMetadata] = None, + settings_metadata: _Optional[_metadata.Settings] = None, + data_metadata: _Optional[_metadata.Data] = None, ): callbacks = [ _ModelCheckpoint( @@ -1269,18 +1294,18 @@ def get_callbacks( return callbacks -class TrainOutput(NamedTuple): +class TrainOutput(_NamedTuple): """ :param model: The trained model :param simpliifed_trianer_metadata: The metadata summarizing training with the simplified trainer. """ - model: Optional[LightningModule] - metadata: metadata.TrainingMetadata + model: _Optional[_LightningModule] + metadata: _metadata.TrainingMetadata -def _get_final_latency(latency_analysis: metadata.Latency) -> int: +def _get_final_latency(latency_analysis: _metadata.Latency) -> int: if latency_analysis.manual is not None: latency = latency_analysis.manual print(f"Latency provided as {latency_analysis.manual}; override calibration") @@ -1294,27 +1319,27 @@ def train( input_path: str, output_path: str, train_path: str, - input_version: Optional[Version] = None, # Deprecate? + input_version: _Optional[_Version] = None, # Deprecate? 
epochs=100, - delay: Optional[int] = None, - latency: Optional[int] = None, + delay: _Optional[int] = None, + latency: _Optional[int] = None, model_type: str = "WaveNet", - architecture: Union[Architecture, str] = Architecture.STANDARD, + architecture: _Union[Architecture, str] = Architecture.STANDARD, batch_size: int = 16, ny: int = _NY_DEFAULT, lr=0.004, lr_decay=0.007, - seed: Optional[int] = 0, + seed: _Optional[int] = 0, save_plot: bool = False, silent: bool = False, modelname: str = "model", ignore_checks: bool = False, local: bool = False, fit_mrstft: bool = True, - threshold_esr: Optional[bool] = None, - user_metadata: Optional[UserMetadata] = None, - fast_dev_run: Union[bool, int] = False, -) -> Optional[TrainOutput]: + threshold_esr: _Optional[bool] = None, + user_metadata: _Optional[_UserMetadata] = None, + fast_dev_run: _Union[bool, int] = False, +) -> _Optional[TrainOutput]: """ :param lr_decay: =1-gamma for Exponential learning rate decay. :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. @@ -1322,8 +1347,8 @@ def train( """ def parse_user_latency( - delay: Optional[int], latency: Optional[int] - ) -> Optional[int]: + delay: _Optional[int], latency: _Optional[int] + ) -> _Optional[int]: if delay is not None: if latency is not None: raise ValueError("Both delay and latency are provided; use latency!") @@ -1332,7 +1357,7 @@ def parse_user_latency( return latency if seed is not None: - torch.manual_seed(seed) + _torch.manual_seed(seed) # HACK: We need to check the sample rates and lengths of the audio here or else # It will look like a bad self-ESR (Issue 473) @@ -1384,9 +1409,9 @@ def parse_user_latency( print("Exiting core training...") return TrainOutput( model=None, - metadata=metadata.TrainingMetadata( - settings=metadata.Settings(ignore_checks=ignore_checks), - data=metadata.Data( + metadata=_metadata.TrainingMetadata( + settings=_metadata.Settings(ignore_checks=ignore_checks), + data=_metadata.Data( latency=latency_analysis, checks=data_check_output ), validation_esr=None, @@ -1417,7 +1442,7 @@ def parse_user_latency( # * Model is re-instantiated after training anyways. # (Hacky) solution: set sample rate in model from dataloader after second # instantiation from final checkpoint. 
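
An aside on two of the knobs parsed above: per the docstring, `lr` and `lr_decay` drive an exponential learning-rate schedule with gamma = 1 - lr_decay. A standalone sketch of what the defaults imply over 100 epochs, in plain PyTorch for illustration; this is not the function's actual internal wiring:

    import torch

    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([p], lr=0.004)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=1 - 0.007)
    for _ in range(100):
        opt.step()    # a real training step would go here
        sched.step()  # multiplies the learning rate by 0.993
    print(opt.param_groups[0]["lr"])  # ~0.004 * 0.993**100 ≈ 0.002
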
- model = LightningModule.init_from_config(model_config) + model = _LightningModule.init_from_config(model_config) train_dataloader, val_dataloader = _get_dataloaders( data_config, learning_config, model ) @@ -1431,10 +1456,10 @@ def parse_user_latency( model.net.sample_rate = sample_rate # Put together the metadata that's needed in checkpoints: - settings_metadata = metadata.Settings(ignore_checks=ignore_checks) - data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output) + settings_metadata = _metadata.Settings(ignore_checks=ignore_checks) + data_metadata = _metadata.Data(latency=latency_analysis, checks=data_check_output) - trainer = pl.Trainer( + trainer = _pl.Trainer( callbacks=get_callbacks( threshold_esr, user_metadata=user_metadata, @@ -1446,21 +1471,21 @@ def parse_user_latency( **learning_config["trainer"], ) # Suppress the PossibleUserWarning about num_workers (Issue 345) - with filter_warnings("ignore", category=PossibleUserWarning): + with _filter_warnings("ignore", category=_PossibleUserWarning): trainer.fit(model, train_dataloader, val_dataloader) # Go to best checkpoint best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = LightningModule.load_from_checkpoint( + model = _LightningModule.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, - **LightningModule.parse_config(model_config), + **_LightningModule.parse_config(model_config), ) model.cpu() model.eval() model.net.sample_rate = sample_rate # Hack, part 2 - def window_kwargs(version: Version): + def window_kwargs(version: _Version): if version.major == 1: return dict( window_start=100_000, # Start of the plotting window, in samples @@ -1487,7 +1512,7 @@ def window_kwargs(version: Version): ) return TrainOutput( model=model, - metadata=metadata.TrainingMetadata( + metadata=_metadata.TrainingMetadata( settings=settings_metadata, data=data_metadata, validation_esr=validation_esr, @@ -1495,7 +1520,7 @@ def window_kwargs(version: Version): ) -class DataInputValidation(BaseModel): +class DataInputValidation(_BaseModel): passed: bool @@ -1512,49 +1537,49 @@ def validate_input(input_path) -> DataInputValidation: return DataInputValidation(passed=False) -class _PyTorchDataSplitValidation(BaseModel): +class _PyTorchDataSplitValidation(_BaseModel): """ :param msg: On exception, catch and assign. 
Otherwise None """ passed: bool - msg: Optional[str] + msg: _Optional[str] -class _PyTorchDataValidation(BaseModel): +class _PyTorchDataValidation(_BaseModel): passed: bool train: _PyTorchDataSplitValidation # cf Split.TRAIN validation: _PyTorchDataSplitValidation # Split.VALIDATION -class _SampleRateValidation(BaseModel): +class _SampleRateValidation(_BaseModel): passed: bool input: int output: int -class _LengthValidation(BaseModel): +class _LengthValidation(_BaseModel): passed: bool delta_seconds: float -class DataValidationOutput(BaseModel): +class DataValidationOutput(_BaseModel): passed: bool passed_critical: bool sample_rate: _SampleRateValidation length: _LengthValidation input_version: str - latency: metadata.Latency - checks: metadata.DataChecks + latency: _metadata.Latency + checks: _metadata.DataChecks pytorch: _PyTorchDataValidation def _check_audio_sample_rates( - input_path: Path, - output_path: Path, + input_path: _Path, + output_path: _Path, ) -> _SampleRateValidation: - _, x_info = wav_to_np(input_path, info=True) - _, y_info = wav_to_np(output_path, info=True) + _, x_info = _wav_to_np(input_path, info=True) + _, y_info = _wav_to_np(output_path, info=True) return _SampleRateValidation( passed=x_info.rate == y_info.rate, @@ -1564,10 +1589,10 @@ def _check_audio_sample_rates( def _check_audio_lengths( - input_path: Path, - output_path: Path, - max_under_seconds: Optional[float] = 0.0, - max_over_seconds: Optional[float] = 1.0, + input_path: _Path, + output_path: _Path, + max_under_seconds: _Optional[float] = 0.0, + max_over_seconds: _Optional[float] = 1.0, ) -> _LengthValidation: """ Check that the input and output have the right lengths compared to each @@ -1584,8 +1609,8 @@ def _check_audio_lengths( value of 1.0 means that the output can't be more than a second longer than the input. """ - x, x_info = wav_to_np(input_path, info=True) - y, y_info = wav_to_np(output_path, info=True) + x, x_info = _wav_to_np(input_path, info=True) + y, y_info = _wav_to_np(output_path, info=True) length_input = len(x) / x_info.rate length_output = len(y) / y_info.rate @@ -1601,9 +1626,9 @@ def _check_audio_lengths( def validate_data( - input_path: Path, - output_path: Path, - user_latency: Optional[int], + input_path: _Path, + output_path: _Path, + user_latency: _Optional[int], num_output_samples_per_datum: int = _NY_DEFAULT, ): """ @@ -1660,14 +1685,14 @@ def validate_data( # be unlikely to make a difference. Still, would be nice to fix. 
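
The audio checks above reduce to two rules: the sample rates must match exactly, and the output may be at most `max_under_seconds` shorter and `max_over_seconds` longer than the input. A toy illustration of the length rule with the defaults from the signature (0.0 and 1.0 seconds); the helper name is hypothetical:

    def lengths_ok(
        n_x: int,     # input length, samples
        n_y: int,     # output length, samples
        rate: float,  # shared sample rate, Hz
        max_under_seconds: float = 0.0,
        max_over_seconds: float = 1.0,
    ) -> bool:
        # delta > 0 means the output is longer than the input.
        delta = (n_y - n_x) / rate
        return -max_under_seconds <= delta <= max_over_seconds

    print(lengths_ok(180 * 48_000, 180 * 48_000 + 24_000, 48_000.0))  # True: +0.5 s
    print(lengths_ok(180 * 48_000, 182 * 48_000, 48_000.0))           # False: +2 s
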
data_config["common"]["nx"] = 4096 - pytorch_data_split_validation_dict: Dict[str, _PyTorchDataSplitValidation] = {} - for split in Split: + pytorch_data_split_validation_dict: _Dict[str, _PyTorchDataSplitValidation] = {} + for split in _Split: try: - init_dataset(data_config, split) + _init_dataset(data_config, split) pytorch_data_split_validation_dict[split.value] = ( _PyTorchDataSplitValidation(passed=True, msg=None) ) - except DataError as e: + except _DataError as e: pytorch_data_split_validation_dict[split.value] = ( _PyTorchDataSplitValidation(passed=False, msg=str(e)) ) diff --git a/nam/train/full.py b/nam/train/full.py index 57410ee..09f5c18 100644 --- a/nam/train/full.py +++ b/nam/train/full.py @@ -2,31 +2,37 @@ # Created Date: Tuesday March 26th 2024 # Author: Enrico Schifano (eraz1997@live.it) -import json -from pathlib import Path -from time import time -from typing import Optional, Union -from warnings import warn - -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -from pytorch_lightning.utilities.warnings import PossibleUserWarning -import torch -from torch.utils.data import DataLoader - -from nam.data import ConcatDataset, Split, init_dataset -from nam.train.lightning_module import LightningModule -from nam.util import filter_warnings - -torch.manual_seed(0) - - -def _rms(x: Union[np.ndarray, torch.Tensor]) -> float: - if isinstance(x, np.ndarray): - return np.sqrt(np.mean(np.square(x))) - elif isinstance(x, torch.Tensor): - return torch.sqrt(torch.mean(torch.square(x))).item() +import json as _json +from pathlib import Path as _Path +from time import time as _time +from typing import Optional as _Optional, Union as _Union +from warnings import warn as _warn + +import matplotlib.pyplot as _plt +import numpy as _np +import pytorch_lightning as _pl +from pytorch_lightning.utilities.warnings import ( + PossibleUserWarning as _PossibleUserWarning, +) +import torch as _torch +from torch.utils.data import DataLoader as _DataLoader + +from nam.data import ( + ConcatDataset as _ConcatDataset, + Split as _Split, + init_dataset as _init_dataset, +) +from nam.train.lightning_module import LightningModule as _LightningModule +from nam.util import filter_warnings as _filter_warnings + +_torch.manual_seed(0) + + +def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float: + if isinstance(x, _np.ndarray): + return _np.sqrt(_np.mean(_np.square(x))) + elif isinstance(x, _torch.Tensor): + return _torch.sqrt(_torch.mean(_torch.square(x))).item() else: raise TypeError(type(x)) @@ -36,18 +42,18 @@ def _plot( ds, savefig=None, show=True, - window_start: Optional[int] = None, - window_end: Optional[int] = None, + window_start: _Optional[int] = None, + window_end: _Optional[int] = None, ): - if isinstance(ds, ConcatDataset): + if isinstance(ds, _ConcatDataset): def extend_savefig(i, savefig): if savefig is None: return None - savefig = Path(savefig) + savefig = _Path(savefig) extension = savefig.name.split(".")[-1] stem = savefig.name[: -len(extension) - 1] - return Path(savefig.parent, f"{stem}_{i}.{extension}") + return _Path(savefig.parent, f"{stem}_{i}.{extension}") for i, ds_i in enumerate(ds.datasets): _plot( @@ -59,29 +65,29 @@ def extend_savefig(i, savefig): window_end=window_end, ) return - with torch.no_grad(): + with _torch.no_grad(): tx = len(ds.x) / 48_000 print(f"Run (t={tx:.2f})") - t0 = time() + t0 = _time() output = model(ds.x).flatten().cpu().numpy() - t1 = time() + t1 = _time() try: rt = f"{tx / (t1 - t0):.2f}" except ZeroDivisionError as e: rt 
= "???" print(f"Took {t1 - t0:.2f} ({rt}x)") - plt.figure(figsize=(16, 5)) - plt.plot(output[window_start:window_end], label="Prediction") - plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") - nrmse = _rms(torch.Tensor(output) - ds.y) / _rms(ds.y) + _plt.figure(figsize=(16, 5)) + _plt.plot(output[window_start:window_end], label="Prediction") + _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") + nrmse = _rms(_torch.Tensor(output) - ds.y) / _rms(ds.y) esr = nrmse**2 - plt.title(f"ESR={esr:.3f}") - plt.legend() + _plt.title(f"ESR={esr:.3f}") + _plt.legend() if savefig is not None: - plt.savefig(savefig) + _plt.savefig(savefig) if show: - plt.show() + _plt.show() def _create_callbacks(learning_config): @@ -102,7 +108,7 @@ def _create_callbacks(learning_config): ) } - checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint( + checkpoint_best = _pl.callbacks.model_checkpoint.ModelCheckpoint( filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}", save_top_k=3, monitor="val_loss", @@ -111,14 +117,14 @@ def _create_callbacks(learning_config): # return [checkpoint_best, checkpoint_last] # The last epoch that was finished. - checkpoint_epoch = pl.callbacks.model_checkpoint.ModelCheckpoint( + checkpoint_epoch = _pl.callbacks.model_checkpoint.ModelCheckpoint( filename="checkpoint_epoch_{epoch:04d}", every_n_epochs=1 ) if not validate_inside_epoch: return [checkpoint_best, checkpoint_epoch] else: # The last validation pass, whether at the end of an epoch or not - checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint( + checkpoint_last = _pl.callbacks.model_checkpoint.ModelCheckpoint( filename="checkpoint_last_{epoch:04d}_{step}", **kwargs ) return [checkpoint_best, checkpoint_last, checkpoint_epoch] @@ -128,7 +134,7 @@ def main( data_config, model_config, learning_config, - outdir: Path, + outdir: _Path, no_show: bool = False, make_plots=True, ): @@ -140,35 +146,37 @@ def main( ("model", model_config), ("learning", learning_config), ): - with open(Path(outdir, f"config_{basename}.json"), "w") as fp: - json.dump(config, fp, indent=4) + with open(_Path(outdir, f"config_{basename}.json"), "w") as fp: + _json.dump(config, fp, indent=4) - model = LightningModule.init_from_config(model_config) + model = _LightningModule.init_from_config(model_config) # Add receptive field to data config: data_config["common"] = data_config.get("common", {}) if "nx" in data_config["common"]: - warn( + _warn( f"Overriding data nx={data_config['common']['nx']} with model requried {model.net.receptive_field}" ) data_config["common"]["nx"] = model.net.receptive_field - dataset_train = init_dataset(data_config, Split.TRAIN) - dataset_validation = init_dataset(data_config, Split.VALIDATION) + dataset_train = _init_dataset(data_config, _Split.TRAIN) + dataset_validation = _init_dataset(data_config, _Split.VALIDATION) if dataset_train.sample_rate != dataset_validation.sample_rate: raise RuntimeError( "Train and validation data loaders have different data set sample rates: " f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}" ) model.net.sample_rate = dataset_train.sample_rate - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + train_dataloader = _DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = _DataLoader( + dataset_validation, **learning_config["val_dataloader"] + ) - trainer = pl.Trainer( + trainer = 
_pl.Trainer( callbacks=_create_callbacks(learning_config), default_root_dir=outdir, **learning_config["trainer"], ) - with filter_warnings("ignore", category=PossibleUserWarning): + with _filter_warnings("ignore", category=_PossibleUserWarning): trainer.fit( model, train_dataloader, @@ -178,9 +186,9 @@ def main( # Go to best checkpoint best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = LightningModule.load_from_checkpoint( + model = _LightningModule.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, - **LightningModule.parse_config(model_config), + **_LightningModule.parse_config(model_config), ) model.cpu() model.eval() @@ -188,7 +196,7 @@ def main( _plot( model, dataset_validation, - savefig=Path(outdir, "comparison.png"), + savefig=_Path(outdir, "comparison.png"), window_start=100_000, window_end=110_000, show=False, diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py index 33a49b6..057324e 100644 --- a/nam/train/gui/__init__.py +++ b/nam/train/gui/__init__.py @@ -10,16 +10,16 @@ >>> run() """ -import abc -import re -import requests -import tkinter as tk -import subprocess -import sys -import webbrowser -from dataclasses import dataclass -from enum import Enum -from functools import partial +import abc as _abc +import re as _re +import requests as _requests +import tkinter as _tk +import subprocess as _subprocess +import sys as _sys +import webbrowser as _webbrowser +from dataclasses import dataclass as _dataclass +from enum import Enum as _Enum +from functools import partial as _partial try: # Not supported in Colab from idlelib.tooltip import Hovertip @@ -34,26 +34,43 @@ def __init__(self, *args, **kwargs): pass -from pathlib import Path -from tkinter import filedialog -from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence +from pathlib import Path as _Path +from tkinter import filedialog as _filedialog +from typing import ( + Any as _Any, + Callable as _Callable, + Dict as _Dict, + NamedTuple as _NamedTuple, + Optional as _Optional, + Sequence as _Sequence, +) try: # 3rd-party and 1st-party imports - import torch + import torch as _torch from nam import __version__ - from nam.data import Split - from nam.train import core - from nam.train.gui._resources import settings - from nam.models.metadata import GearType, UserMetadata, ToneType + from nam.data import Split as _Split + from nam.train import core as _core + from nam.train.gui._resources import settings as _settings + from nam.models.metadata import ( + GearType as _GearType, + UserMetadata as _UserMetadata, + ToneType as _ToneType, + ) # Ok private access here--this is technically allowed access - from nam.train import metadata - from nam.train._names import INPUT_BASENAMES, LATEST_VERSION - from nam.train._version import Version, get_current_version + from nam.train import metadata as _metadata + from nam.train._names import ( + INPUT_BASENAMES as _INPUT_BASENAMES, + LATEST_VERSION as _LATEST_VERSION, + ) + from nam.train._version import ( + Version as _Version, + get_current_version as _get_current_version, + ) _install_is_valid = True - _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available() + _HAVE_ACCELERATOR = _torch.cuda.is_available() or _torch.backends.mps.is_available() except ImportError: _install_is_valid = False _HAVE_ACCELERATOR = False @@ -81,13 +98,13 @@ def __init__(self, *args, **kwargs): def _is_mac() -> bool: - return sys.platform == "darwin" + return _sys.platform == "darwin" 
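
Stepping back: the rename running through this entire diff, `import x` becoming `import x as _x`, is the same move in every file. A leading underscore keeps the alias out of `from module import *` and reads as private, so each module's public namespace shrinks to what it deliberately defines. A hypothetical minimal module showing the effect:

    # demo.py (hypothetical module, for illustration only)
    import json as _json  # private alias: an implementation detail of demo

    def dumps_sorted(obj) -> str:
        """Public helper; _json stays out of demo's advertised API."""
        return _json.dumps(obj, sort_keys=True)

    # >>> import demo
    # >>> [name for name in dir(demo) if not name.startswith("_")]
    # ['dumps_sorted']  # 'json' no longer leaks into the public namespace
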
_SYSTEM_TEXT_COLOR = "systemTextColor" if _is_mac() else "black" -@dataclass +@_dataclass class AdvancedOptions(object): """ :param architecture: Which architecture to use. @@ -99,14 +116,14 @@ class AdvancedOptions(object): stop. """ - architecture: core.Architecture + architecture: _core.Architecture num_epochs: int - latency: Optional[int] + latency: _Optional[int] ignore_checks: bool - threshold_esr: Optional[float] + threshold_esr: _Optional[float] -class _PathType(Enum): +class _PathType(_Enum): FILE = "file" DIRECTORY = "directory" MULTIFILE = "multifile" @@ -119,42 +136,42 @@ class _PathButton(object): def __init__( self, - frame: tk.Frame, + frame: _tk.Frame, button_text: str, info_str: str, path_type: _PathType, - path_key: settings.PathKey, - hooks: Optional[Sequence[Callable[[], None]]] = None, + path_key: _settings.PathKey, + hooks: _Optional[_Sequence[_Callable[[], None]]] = None, color_when_not_set: str = "#EF0000", # Darker red color_when_set: str = _SYSTEM_TEXT_COLOR, - default: Optional[Path] = None, + default: _Optional[_Path] = None, ): """ :param hooks: Callables run at the end of setting the value. """ self._button_text = button_text self._info_str = info_str - self._path: Optional[Path] = default + self._path: _Optional[_Path] = default self._path_type = path_type self._path_key = path_key self._frame = frame self._widgets = {} - self._widgets["button"] = tk.Button( + self._widgets["button"] = _tk.Button( self._frame, text=button_text, width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, command=self._set_val, ) - self._widgets["button"].pack(side=tk.LEFT) - self._widgets["label"] = tk.Label( + self._widgets["button"].pack(side=_tk.LEFT) + self._widgets["label"] = _tk.Label( self._frame, width=_TEXT_WIDTH, height=_BUTTON_HEIGHT, bg=None, anchor="w", ) - self._widgets["label"].pack(side=tk.LEFT) + self._widgets["label"].pack(side=_tk.LEFT) self._hooks = hooks self._color_when_not_set = color_when_not_set self._color_when_set = color_when_set @@ -173,7 +190,7 @@ def __setitem__(self, key, val): ) @property - def val(self) -> Optional[Path]: + def val(self) -> _Optional[_Path]: return self._path def _set_text(self): @@ -189,7 +206,7 @@ def _set_text(self): ] = f"{self._button_text.capitalize()} set to {val}" def _set_val(self): - last_path = settings.get_last_path(self._path_key) + last_path = _settings.get_last_path(self._path_key) if last_path is None: initial_dir = None elif not last_path.is_dir(): @@ -197,15 +214,15 @@ def _set_val(self): else: initial_dir = last_path result = { - _PathType.FILE: filedialog.askopenfilename, - _PathType.DIRECTORY: filedialog.askdirectory, - _PathType.MULTIFILE: filedialog.askopenfilenames, + _PathType.FILE: _filedialog.askopenfilename, + _PathType.DIRECTORY: _filedialog.askdirectory, + _PathType.MULTIFILE: _filedialog.askopenfilenames, }[self._path_type](initialdir=str(initial_dir)) if result != "": self._path = result - settings.set_last_path( + _settings.set_last_path( self._path_key, - Path(result[0] if self._path_type == _PathType.MULTIFILE else result), + _Path(result[0] if self._path_type == _PathType.MULTIFILE else result), ) self._set_text() @@ -218,14 +235,14 @@ class _InputPathButton(_PathButton): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Download the training file! 
- self._widgets["button_download_input"] = tk.Button( + self._widgets["button_download_input"] = _tk.Button( self._frame, text="Download input file", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, command=self._download_input_file, ) - self._widgets["button_download_input"].pack(side=tk.RIGHT) + self._widgets["button_download_input"].pack(side=_tk.RIGHT) @classmethod def _download_input_file(cls): @@ -237,20 +254,20 @@ def _download_input_file(cls): "v1.wav": "", } # Pick the most recent file. - for input_basename in INPUT_BASENAMES: + for input_basename in _INPUT_BASENAMES: name = input_basename.name url = file_urls.get(name) if url: - if name != LATEST_VERSION.name: + if name != _LATEST_VERSION.name: print( f"WARNING: File {name} is out of date. " "This needs to be updated!" ) - webbrowser.open(url) + _webbrowser.open(url) return -class _CheckboxKeys(Enum): +class _CheckboxKeys(_Enum): """ Keys for checkboxes """ @@ -259,13 +276,13 @@ class _CheckboxKeys(Enum): SAVE_PLOT = "save_plot" -class _TopLevelWithOk(tk.Toplevel): +class _TopLevelWithOk(_tk.Toplevel): """ Toplevel with an Ok button (provide yourself!) """ def __init__( - self, on_ok: Callable[[None], None], resume_main: Callable[[None], None] + self, on_ok: _Callable[[None], None], resume_main: _Callable[[None], None] ): """ :param on_ok: What to do when "Ok" button is pressed @@ -281,17 +298,17 @@ def destroy(self, pressed_ok: bool = False): super().destroy() -class _TopLevelWithYesNo(tk.Toplevel): +class _TopLevelWithYesNo(_tk.Toplevel): """ Toplevel holding functions for yes/no buttons to close """ def __init__( self, - on_yes: Callable[[None], None], - on_no: Callable[[None], None], - on_close: Optional[Callable[[None], None]], - resume_main: Callable[[None], None], + on_yes: _Callable[[None], None], + on_no: _Callable[[None], None], + on_close: _Optional[_Callable[[None], None]], + resume_main: _Callable[[None], None], ): """ :param on_yes: What to do when "Yes" button is pressed. 
@@ -321,13 +338,13 @@ class _OkModal(object):
     Message and OK button
     """

-    def __init__(self, resume_main, msg: str, label_kwargs: Optional[dict] = None):
+    def __init__(self, resume_main, msg: str, label_kwargs: _Optional[dict] = None):
         label_kwargs = {} if label_kwargs is None else label_kwargs

         self._root = _TopLevelWithOk((lambda: None), resume_main)
-        self._text = tk.Label(self._root, text=msg, **label_kwargs)
+        self._text = _tk.Label(self._root, text=msg, **label_kwargs)
         self._text.pack()
-        self._ok = tk.Button(
+        self._ok = _tk.Button(
             self._root,
             text="Ok",
             width=_BUTTON_WIDTH,
@@ -344,38 +361,38 @@ class _YesNoModal(object):

     def __init__(
         self,
-        on_yes: Callable[[None], None],
-        on_no: Callable[[None], None],
+        on_yes: _Callable[[None], None],
+        on_no: _Callable[[None], None],
         resume_main,
         msg: str,
-        on_close: Optional[Callable[[None], None]] = None,
-        label_kwargs: Optional[dict] = None,
+        on_close: _Optional[_Callable[[None], None]] = None,
+        label_kwargs: _Optional[dict] = None,
     ):
         label_kwargs = {} if label_kwargs is None else label_kwargs
         self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main)
-        self._text = tk.Label(self._root, text=msg, **label_kwargs)
+        self._text = _tk.Label(self._root, text=msg, **label_kwargs)
         self._text.pack()
-        self._buttons_frame = tk.Frame(self._root)
+        self._buttons_frame = _tk.Frame(self._root)
         self._buttons_frame.pack()
-        self._yes = tk.Button(
+        self._yes = _tk.Button(
             self._buttons_frame,
             text="Yes",
             width=_BUTTON_WIDTH,
             height=_BUTTON_HEIGHT,
             command=lambda: self._root.destroy(pressed_yes=True),
         )
-        self._yes.pack(side=tk.LEFT)
-        self._no = tk.Button(
+        self._yes.pack(side=_tk.LEFT)
+        self._no = _tk.Button(
             self._buttons_frame,
             text="No",
             width=_BUTTON_WIDTH,
             height=_BUTTON_HEIGHT,
             command=lambda: self._root.destroy(pressed_no=True),
         )
-        self._no.pack(side=tk.RIGHT)
+        self._no.pack(side=_tk.RIGHT)


-class _GUIWidgets(Enum):
+class _GUIWidgets(_Enum):
     INPUT_PATH = "input_path"
     OUTPUT_PATH = "output_path"
     TRAINING_DESTINATION = "training_destination"
@@ -385,57 +402,57 @@ class _GUIWidgets(Enum):
     UPDATE = "update"


-@dataclass
+@_dataclass
 class Checkbox(object):
-    variable: tk.BooleanVar
-    check_button: tk.Checkbutton
+    variable: _tk.BooleanVar
+    check_button: _tk.Checkbutton


 class GUI(object):
     def __init__(self):
-        self._root = tk.Tk()
+        self._root = _tk.Tk()
         self._root.title(f"NAM Trainer - v{__version__}")
         self._widgets = {}

         # Buttons for paths:
-        self._frame_input = tk.Frame(self._root)
+        self._frame_input = _tk.Frame(self._root)
         self._frame_input.pack(anchor="w")
         self._widgets[_GUIWidgets.INPUT_PATH] = _InputPathButton(
             self._frame_input,
             "Input Audio",
-            f"Select input (DI) file (e.g. {LATEST_VERSION.name})",
+            f"Select input (DI) file (e.g. {_LATEST_VERSION.name})",
             _PathType.FILE,
-            settings.PathKey.INPUT_FILE,
+            _settings.PathKey.INPUT_FILE,
             hooks=[self._check_button_states],
         )
-        self._frame_output_path = tk.Frame(self._root)
+        self._frame_output_path = _tk.Frame(self._root)
         self._frame_output_path.pack(anchor="w")
         self._widgets[_GUIWidgets.OUTPUT_PATH] = _PathButton(
             self._frame_output_path,
             "Output Audio",
             "Select output (reamped) file - (Choose MULTIPLE FILES to enable BATCH TRAINING)",
             _PathType.MULTIFILE,
-            settings.PathKey.OUTPUT_FILE,
+            _settings.PathKey.OUTPUT_FILE,
             hooks=[self._check_button_states],
         )
-        self._frame_train_destination = tk.Frame(self._root)
+        self._frame_train_destination = _tk.Frame(self._root)
         self._frame_train_destination.pack(anchor="w")
         self._widgets[_GUIWidgets.TRAINING_DESTINATION] = _PathButton(
             self._frame_train_destination,
             "Train Destination",
             "Select training output directory",
             _PathType.DIRECTORY,
-            settings.PathKey.TRAINING_DESTINATION,
+            _settings.PathKey.TRAINING_DESTINATION,
             hooks=[self._check_button_states],
         )

         # Metadata
-        self.user_metadata = UserMetadata()
-        self._frame_metadata = tk.Frame(self._root)
+        self.user_metadata = _UserMetadata()
+        self._frame_metadata = _tk.Frame(self._root)
         self._frame_metadata.pack(anchor="w")
-        self._widgets["metadata"] = tk.Button(
+        self._widgets["metadata"] = _tk.Button(
             self._frame_metadata,
             text="Metadata...",
             width=_BUTTON_WIDTH,
@@ -449,16 +466,16 @@ def __init__(self):
         self._get_additional_options_frame()

         # Last frames: avdanced options & train in the SE corner:
-        self._frame_advanced_options = tk.Frame(self._root)
-        self._frame_train = tk.Frame(self._root)
-        self._frame_update = tk.Frame(self._root)
+        self._frame_advanced_options = _tk.Frame(self._root)
+        self._frame_train = _tk.Frame(self._root)
+        self._frame_update = _tk.Frame(self._root)
         # Pack must be in reverse order
-        self._frame_update.pack(side=tk.BOTTOM, anchor="e")
-        self._frame_train.pack(side=tk.BOTTOM, anchor="e")
-        self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")
+        self._frame_update.pack(side=_tk.BOTTOM, anchor="e")
+        self._frame_train.pack(side=_tk.BOTTOM, anchor="e")
+        self._frame_advanced_options.pack(side=_tk.BOTTOM, anchor="e")

         # Advanced options for training
-        default_architecture = core.Architecture.STANDARD
+        default_architecture = _core.Architecture.STANDARD
         self.advanced_options = AdvancedOptions(
             default_architecture,
             _DEFAULT_NUM_EPOCHS,
@@ -468,7 +485,7 @@ def __init__(self):
         )

         # Window to edit them:
-        self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = tk.Button(
+        self._widgets[_GUIWidgets.ADVANCED_OPTIONS] = _tk.Button(
             self._frame_advanced_options,
             text="Advanced options...",
             width=_BUTTON_WIDTH,
@@ -479,7 +496,7 @@ def __init__(self):

         # Train button
-        self._widgets[_GUIWidgets.TRAIN] = tk.Button(
+        self._widgets[_GUIWidgets.TRAIN] = _tk.Button(
             self._frame_train,
             text="Train",
             width=_BUTTON_WIDTH,
@@ -492,7 +509,7 @@ def __init__(self):

         self._check_button_states()

-    def core_train_kwargs(self) -> Dict[str, Any]:
+    def core_train_kwargs(self) -> _Dict[str, _Any]:
         """
         Get any additional kwargs to provide to `core.train`
         """
@@ -528,29 +545,29 @@ def _check_button_states(self):
                 self._widgets[_GUIWidgets.TRAINING_DESTINATION],
             )
         ):
-            self._widgets[_GUIWidgets.TRAIN]["state"] = tk.DISABLED
+            self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.DISABLED
             return
-        self._widgets[_GUIWidgets.TRAIN]["state"] = tk.NORMAL
+        self._widgets[_GUIWidgets.TRAIN]["state"] = _tk.NORMAL

     def _get_additional_options_frame(self):
         # Checkboxes
         # TODO get these definitions into __init__()
-        self._frame_checkboxes = tk.Frame(self._root)
-        self._frame_checkboxes.pack(side=tk.LEFT)
+        self._frame_checkboxes = _tk.Frame(self._root)
+        self._frame_checkboxes.pack(side=_tk.LEFT)
         row = 1

         def make_checkbox(
             key: _CheckboxKeys, text: str, default_value: bool
         ) -> Checkbox:
-            variable = tk.BooleanVar()
+            variable = _tk.BooleanVar()
             variable.set(default_value)
-            check_button = tk.Checkbutton(
+            check_button = _tk.Checkbutton(
                 self._frame_checkboxes, text=text, variable=variable
             )
             self._checkboxes[key] = Checkbox(variable, check_button)
             self._widgets[key] = check_button  # For tracking in set-all-widgets ops

-        self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
+        self._checkboxes: _Dict[_CheckboxKeys, Checkbox] = dict()
         make_checkbox(
             _CheckboxKeys.SILENT_TRAINING,
             "Silent run (suggested for batch training)",
@@ -568,7 +585,7 @@ def mainloop(self):
         self._root.mainloop()

     def _disable(self):
-        self._set_all_widget_states_to(tk.DISABLED)
+        self._set_all_widget_states_to(_tk.DISABLED)

     def _open_advanced_options(self):
         """
@@ -584,15 +601,15 @@ def _open_metadata(self):
         self._wait_while_func(lambda resume: UserMetadataGUI(resume, self))

-    def _pack_update_button(self, version_from: Version, version_to: Version):
+    def _pack_update_button(self, version_from: _Version, version_to: _Version):
         """
         Pack a button that a user can click to update
         """

         def update_nam():
-            result = subprocess.run(
+            result = _subprocess.run(
                 [
-                    f"{sys.executable}",
+                    f"{_sys.executable}",
                     "-m",
                     "pip",
                     "install",
@@ -611,7 +628,7 @@ def update_nam():
                 "Update failed! See logs.",
             )

-        self._widgets[_GUIWidgets.UPDATE] = tk.Button(
+        self._widgets[_GUIWidgets.UPDATE] = _tk.Button(
             self._frame_update,
             text=f"Update ({str(version_from)} -> {str(version_to)})",
             width=_BUTTON_WIDTH,
@@ -621,18 +638,18 @@ def update_nam():
         self._widgets[_GUIWidgets.UPDATE].pack()

     def _pack_update_button_if_update_is_available(self):
-        class UpdateInfo(NamedTuple):
+        class UpdateInfo(_NamedTuple):
             available: bool
-            current_version: Version
-            new_version: Optional[Version]
+            current_version: _Version
+            new_version: _Optional[_Version]

         def get_info() -> UpdateInfo:
             # TODO error handling
             url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases"
-            current_version = get_current_version()
+            current_version = _get_current_version()
             try:
-                response = requests.get(url)
-            except requests.exceptions.ConnectionError:
+                response = _requests.get(url)
+            except _requests.exceptions.ConnectionError:
                 print("WARNING: Failed to reach the server to check for updates")
                 return UpdateInfo(
                     available=False, current_version=current_version, new_version=None
@@ -651,7 +668,7 @@ def get_info() -> UpdateInfo:
                 if not tag.startswith("v"):
                     print(f"Found invalid version {tag}")
                 else:
-                    this_version = Version.from_string(tag[1:])
+                    this_version = _Version.from_string(tag[1:])
                     if latest_version is None or this_version > latest_version:
                         latest_version = this_version
             else:
@@ -672,7 +689,7 @@ def get_info() -> UpdateInfo:
         )

     def _resume(self):
-        self._set_all_widget_states_to(tk.NORMAL)
+        self._set_all_widget_states_to(_tk.NORMAL)
         self._check_button_states()

     def _set_all_widget_states_to(self, state):
@@ -700,12 +717,12 @@ def _train2(self, ignore_checks=False):
         # Run it
         for file in file_list:
             print(f"Now training {file}")
-            basename = re.sub(r"\.wav$", "", file.split("/")[-1])
+            basename = _re.sub(r"\.wav$", "", file.split("/")[-1])
             user_metadata = (
-                self.user_metadata if self.user_metadata_flag else UserMetadata()
+                self.user_metadata if self.user_metadata_flag else _UserMetadata()
             )
-            train_output = core.train(
+            train_output = _core.train(
                 input_path,
                 file,
                 self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
@@ -735,7 +752,7 @@ def _train2(self, ignore_checks=False):
                 basename=basename,
                 user_metadata=user_metadata,
                 other_metadata={
-                    metadata.TRAINING_KEY: train_output.metadata.model_dump()
+                    _metadata.TRAINING_KEY: train_output.metadata.model_dump()
                 },
             )
         print("Done!")
@@ -745,7 +762,7 @@ def _train2(self, ignore_checks=False):
         self.user_metadata_flag = False

     def _validate_all_data(
-        self, input_path: Path, output_paths: Sequence[Path]
+        self, input_path: _Path, output_paths: _Sequence[_Path]
     ) -> bool:
         """
         Validate all the data.
@@ -757,14 +774,14 @@ def _validate_all_data(
         """

         def make_message_for_file(
-            output_path: str, validation_output: core.DataValidationOutput
+            output_path: str, validation_output: _core.DataValidationOutput
         ) -> str:
             """
             State the file and explain what's wrong with it.
             """
             # TODO put this closer to what it looks at, i.e. core.DataValidationOutput
             msg = (
-                f"\t{Path(output_path).name}:\n"  # They all have the same directory so
+                f"\t{_Path(output_path).name}:\n"  # They all have the same directory so
             )
             if not validation_output.sample_rate.passed:
                 msg += (
@@ -798,14 +815,14 @@ def make_message_for_file(
                 msg += "\t\t* A data check failed (TODO in more detail).\n"
             if not validation_output.pytorch.passed:
                 msg += "\t\t* PyTorch data set errors:\n"
-                for split in Split:
+                for split in _Split:
                     split_validation = getattr(validation_output.pytorch, split.value)
                     if not split_validation.passed:
                         msg += f" * {split.value:10s}: {split_validation.msg}\n"
             return msg

         # Validate input
-        input_validation = core.validate_input(input_path)
+        input_validation = _core.validate_input(input_path)
         if not input_validation.passed:
             self._wait_while_func(
                 (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
@@ -816,7 +833,7 @@ def make_message_for_file(

         user_latency = self.advanced_options.latency
         file_validation_outputs = {
-            output_path: core.validate_data(
+            output_path: _core.validate_data(
                 input_path,
                 output_path,
                 user_latency,
@@ -920,12 +937,12 @@ def _rstripped_str(val):
     return str(val).rstrip()


-class _SettingWidget(abc.ABC):
+class _SettingWidget(_abc.ABC):
     """
     A widget for the user to interact with to set something
     """

-    @abc.abstractmethod
+    @_abc.abstractmethod
     def get(self):
         pass

@@ -936,7 +953,11 @@ class LabeledOptionMenu(_SettingWidget):
     """

     def __init__(
-        self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
+        self,
+        frame: _tk.Frame,
+        label: str,
+        choices: _Enum,
+        default: _Optional[_Enum] = None,
     ):
         """
         :param command: Called to propagate option selection. Is provided with the
@@ -946,7 +967,7 @@ def __init__(
         self._choices = choices
         height = _BUTTON_HEIGHT
         bg = None
-        self._label = tk.Label(
+        self._label = _tk.Label(
             frame,
             width=_ADVANCED_OPTIONS_LEFT_WIDTH,
             height=height,
@@ -954,26 +975,26 @@
             anchor="w",
             text=label,
         )
-        self._label.pack(side=tk.LEFT)
+        self._label.pack(side=_tk.LEFT)

-        frame_menu = tk.Frame(frame)
-        frame_menu.pack(side=tk.RIGHT)
+        frame_menu = _tk.Frame(frame)
+        frame_menu.pack(side=_tk.RIGHT)

         self._selected_value = None
         default = (list(choices)[0] if default is None else default).value
-        self._menu = tk.OptionMenu(
+        self._menu = _tk.OptionMenu(
             frame_menu,
-            tk.StringVar(master=frame, value=default, name=label),
+            _tk.StringVar(master=frame, value=default, name=label),
             # default,
             *[choice.value for choice in choices],  # if choice.value!=default],
             command=self._set,
         )
         self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
-        self._menu.pack(side=tk.RIGHT)
+        self._menu.pack(side=_tk.RIGHT)
         # Initialize
         self._set(default)

-    def get(self) -> Enum:
+    def get(self) -> _Enum:
         return self._selected_value

     def _set(self, val: str):
@@ -992,12 +1013,12 @@ class _Hovertip(Hovertip):

     def showcontents(self):
         # Override
-        label = tk.Label(
+        label = _tk.Label(
             self.tipwindow,
             text=self.text,
-            justify=tk.LEFT,
+            justify=_tk.LEFT,
             background="#ffffe0",
-            relief=tk.SOLID,
+            relief=_tk.SOLID,
             borderwidth=1,
             fg="black",
         )
@@ -1011,7 +1032,7 @@ class LabeledText(_SettingWidget):

     def __init__(
         self,
-        frame: tk.Frame,
+        frame: _tk.Frame,
         label: str,
         default=None,
         type=None,
@@ -1028,7 +1049,7 @@ def __init__(
         self._frame = frame
         label_height = 2
         text_height = 1
-        self._label = tk.Label(
+        self._label = _tk.Label(
             frame,
             width=left_width,
             height=label_height,
@@ -1036,15 +1057,15 @@ def __init__(
             anchor="e",
             text=label,
         )
-        self._label.pack(side=tk.LEFT)
+        self._label.pack(side=_tk.LEFT)

-        self._text = tk.Text(
+        self._text = _tk.Text(
             frame,
             width=right_width,
             height=text_height,
             bg=None,
         )
-        self._text.pack(side=tk.RIGHT)
+        self._text.pack(side=_tk.RIGHT)

         self._type = (lambda x: x) if type is None else type

@@ -1052,10 +1073,10 @@ def __init__(
             self._text.insert("1.0", str(default))

         # You can assign a tooltip for the label if you'd like.
-        self.label_tooltip: Optional[_Hovertip] = None
+        self.label_tooltip: _Optional[_Hovertip] = None

     @property
-    def label(self) -> tk.Label:
+    def label(self) -> _tk.Label:
         return self._label

     def get(self):
@@ -1064,7 +1085,7 @@ def get(self):
        May throw a tk.TclError indicating something went wrong getting the value.
        """
        # "1.0" means Line 1, character zero (wat)
-        return self._type(self._text.get("1.0", tk.END))
+        return self._type(self._text.get("1.0", _tk.END))


 class AdvancedOptionsGUI(object):
@@ -1080,9 +1101,9 @@ def __init__(self, resume_main, parent: GUI):
         self.pack()

         # "Ok": apply and destroy
-        self._frame_ok = tk.Frame(self._root)
+        self._frame_ok = _tk.Frame(self._root)
         self._frame_ok.pack()
-        self._button_ok = tk.Button(
+        self._button_ok = _tk.Button(
             self._frame_ok,
             text="Ok",
             width=_BUTTON_WIDTH,
@@ -1113,17 +1134,17 @@ def pack(self):
         # easier to work with.
         # Architecture: radio buttons
-        self._frame_architecture = tk.Frame(self._root)
+        self._frame_architecture = _tk.Frame(self._root)
         self._frame_architecture.pack()
         self._architecture = LabeledOptionMenu(
             self._frame_architecture,
             "Architecture",
-            core.Architecture,
+            _core.Architecture,
             default=self._parent.advanced_options.architecture,
         )

         # Number of epochs: text box
-        self._frame_epochs = tk.Frame(self._root)
+        self._frame_epochs = _tk.Frame(self._root)
         self._frame_epochs.pack()

         self._num_epochs = LabeledText(
@@ -1134,7 +1155,7 @@ def pack(self):
         )

         # Delay: text box
-        self._frame_latency = tk.Frame(self._root)
+        self._frame_latency = _tk.Frame(self._root)
         self._frame_latency.pack()

         self._latency = LabeledText(
@@ -1145,7 +1166,7 @@ def pack(self):
         )

         # Threshold ESR
-        self._frame_threshold_esr = tk.Frame(self._root)
+        self._frame_threshold_esr = _tk.Frame(self._root)
         self._frame_threshold_esr.pack()
         self._threshold_esr = LabeledText(
             self._frame_threshold_esr,
@@ -1168,9 +1189,9 @@ def __init__(self, resume_main, parent: GUI):
         self.pack()

         # "Ok": apply and destroy
-        self._frame_ok = tk.Frame(self._root)
+        self._frame_ok = _tk.Frame(self._root)
         self._frame_ok.pack()
-        self._button_ok = tk.Button(
+        self._button_ok = _tk.Button(
             self._frame_ok,
             text="Ok",
             width=_BUTTON_WIDTH,
@@ -1210,7 +1231,7 @@ def pack(self):
         # TODO things that are `_SettingWidget`s are named carefully, need to make this
         # easier to work with.
-        LabeledText_ = partial(
+        LabeledText_ = _partial(
             LabeledText,
             left_width=_METADATA_LEFT_WIDTH,
             right_width=_METADATA_RIGHT_WIDTH,
         )
         parent = self._parent
         # Name
-        self._frame_name = tk.Frame(self._root)
+        self._frame_name = _tk.Frame(self._root)
         self._frame_name.pack()
         self._name = LabeledText_(
             self._frame_name,
@@ -1227,7 +1248,7 @@
             type=_rstripped_str,
         )
         # Modeled by
-        self._frame_modeled_by = tk.Frame(self._root)
+        self._frame_modeled_by = _tk.Frame(self._root)
         self._frame_modeled_by.pack()
         self._modeled_by = LabeledText_(
             self._frame_modeled_by,
@@ -1236,7 +1257,7 @@
             type=_rstripped_str,
         )
         # Gear make
-        self._frame_gear_make = tk.Frame(self._root)
+        self._frame_gear_make = _tk.Frame(self._root)
         self._frame_gear_make.pack()
         self._gear_make = LabeledText_(
             self._frame_gear_make,
@@ -1245,7 +1266,7 @@
             type=_rstripped_str,
         )
         # Gear model
-        self._frame_gear_model = tk.Frame(self._root)
+        self._frame_gear_model = _tk.Frame(self._root)
         self._frame_gear_model.pack()
         self._gear_model = LabeledText_(
             self._frame_gear_model,
@@ -1254,7 +1275,7 @@
             type=_rstripped_str,
         )
         # Calibration: input & output dBu
-        self._frame_input_dbu = tk.Frame(self._root)
+        self._frame_input_dbu = _tk.Frame(self._root)
         self._frame_input_dbu.pack()
         self._input_level_dbu = LabeledText_(
             self._frame_input_dbu,
@@ -1272,7 +1293,7 @@
                 "Record the value here."
             ),
         )
-        self._frame_output_dbu = tk.Frame(self._root)
+        self._frame_output_dbu = _tk.Frame(self._root)
         self._frame_output_dbu.pack()
         self._output_level_dbu = LabeledText_(
             self._frame_output_dbu,
@@ -1293,36 +1314,36 @@ def pack(self):
             ),
         )
         # Gear type
-        self._frame_gear_type = tk.Frame(self._root)
+        self._frame_gear_type = _tk.Frame(self._root)
         self._frame_gear_type.pack()
         self._gear_type = LabeledOptionMenu(
             self._frame_gear_type,
             "Gear type",
-            GearType,
+            _GearType,
             default=parent.user_metadata.gear_type,
         )
         # Tone type
-        self._frame_tone_type = tk.Frame(self._root)
+        self._frame_tone_type = _tk.Frame(self._root)
         self._frame_tone_type.pack()
         self._tone_type = LabeledOptionMenu(
             self._frame_tone_type,
             "Tone type",
-            ToneType,
+            _ToneType,
             default=parent.user_metadata.tone_type,
         )


 def _install_error():
-    window = tk.Tk()
+    window = _tk.Tk()
     window.title("ERROR")
-    label = tk.Label(
+    label = _tk.Label(
         window,
         width=45,
         height=2,
         text="The NAM training software has not been installed correctly.",
     )
     label.pack()
-    button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
+    button = _tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
     button.pack()
     window.mainloop()
diff --git a/nam/train/lightning_module.py b/nam/train/lightning_module.py
index a915a5c..8e7bd72 100644
--- a/nam/train/lightning_module.py
+++ b/nam/train/lightning_module.py
@@ -10,32 +10,38 @@
 For the base *PyTorch* model containing the actual architecture, see `..models.base`.
 """

-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, NamedTuple, Optional, Tuple
-
-import auraloss
-import logging
-import pytorch_lightning as pl
-import torch
-import torch.nn as nn
-
-from .._core import InitializableFromConfig
-from ..models.conv_net import ConvNet
-from ..models.linear import Linear
+from dataclasses import dataclass as _dataclass
+from enum import Enum as _Enum
+from typing import (
+    Any as _Any,
+    Dict as _Dict,
+    NamedTuple as _NamedTuple,
+    Optional as _Optional,
+    Tuple as _Tuple,
+)
+
+import auraloss as _auraloss
+import logging as _logging
+import pytorch_lightning as _pl
+import torch as _torch
+import torch.nn as _nn
+
+from .._core import InitializableFromConfig as _InitializableFromConfig
+from ..models.conv_net import ConvNet as _ConvNet
+from ..models.linear import Linear as _Linear
 from ..models.losses import (
-    apply_pre_emphasis_filter,
-    esr,
-    multi_resolution_stft_loss,
-    mse_fft,
+    apply_pre_emphasis_filter as _apply_pre_emphasis_filter,
+    esr as _esr,
+    multi_resolution_stft_loss as _multi_resolution_stft_loss,
+    mse_fft as _mse_fft,
 )
-from ..models.recurrent import LSTM
-from ..models.wavenet import WaveNet
+from ..models.recurrent import LSTM as _LSTM
+from ..models.wavenet import WaveNet as _WaveNet

-logger = logging.getLogger(__name__)
+logger = _logging.getLogger(__name__)


-class ValidationLoss(Enum):
+class ValidationLoss(_Enum):
     """
     mse: mean squared error
     esr: error signal ratio (Eq. (10) from
@@ -51,8 +57,8 @@ class ValidationLoss(Enum):
     ESR = "esr"


-@dataclass
-class LossConfig(InitializableFromConfig):
+@_dataclass
+class LossConfig(_InitializableFromConfig):
     """
     :param mrstft_weight: Multi-resolution short-time Fourier transform loss
         coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
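[Editor's note] As a usage sketch of the config whose docstring appears above (the field names come from the next hunk; the 2e-4 value is the docstring's own suggestion, everything else falls back to the dataclass defaults):

from nam.train.lightning_module import LossConfig, ValidationLoss

# Enable the MRSTFT term at the weight the docstring suggests, and track
# ESR instead of the default MSE for the validation loss.
loss_config = LossConfig(
    mrstft_weight=2e-4,
    val_loss=ValidationLoss.ESR,
)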
@@ -64,15 +70,15 @@ class LossConfig(InitializableFromConfig):
     :param pre_
     """

-    mrstft_weight: Optional[float] = None
+    mrstft_weight: _Optional[float] = None
     fourier: bool = False
     mask_first: int = 0
     dc_weight: float = None
     val_loss: ValidationLoss = ValidationLoss.MSE
-    pre_emph_weight: Optional[float] = None
-    pre_emph_coef: Optional[float] = None
-    pre_emph_mrstft_weight: Optional[float] = None
-    pre_emph_mrstft_coef: Optional[float] = None
+    pre_emph_weight: _Optional[float] = None
+    pre_emph_coef: _Optional[float] = None
+    pre_emph_mrstft_weight: _Optional[float] = None
+    pre_emph_mrstft_coef: _Optional[float] = None

     @classmethod
     def parse_config(cls, config):
@@ -97,7 +103,7 @@ def apply_mask(self, *args):
         return tuple(a[..., self.mask_first :] for a in args)

     @classmethod
-    def _get_mrstft_weight(cls, config) -> Optional[float]:
+    def _get_mrstft_weight(cls, config) -> _Optional[float]:
         key = "mrstft_weight"
         wrong_key = "mstft_key"  # Backward compatibility
         if key in config:
@@ -117,20 +123,20 @@ def _get_mrstft_weight(cls, config) -> Optional[float]:
         return None


-class _LossItem(NamedTuple):
-    weight: Optional[float]
-    value: Optional[torch.Tensor]
+class _LossItem(_NamedTuple):
+    weight: _Optional[float]
+    value: _Optional[_torch.Tensor]


 _model_net_init_registry = {
-    "ConvNet": ConvNet.init_from_config,
-    "Linear": Linear.init_from_config,
-    "LSTM": LSTM.init_from_config,
-    "WaveNet": WaveNet.init_from_config,
+    "ConvNet": _ConvNet.init_from_config,
+    "Linear": _Linear.init_from_config,
+    "LSTM": _LSTM.init_from_config,
+    "WaveNet": _WaveNet.init_from_config,
 }


-class LightningModule(pl.LightningModule, InitializableFromConfig):
+class LightningModule(_pl.LightningModule, _InitializableFromConfig):
     """
     The PyTorch Lightning Module that unites the model with its loss and
     optimization recipe.
@@ -139,9 +145,9 @@ class LightningModule(pl.LightningModule, InitializableFromConfig):
     def __init__(
         self,
         net,
-        optimizer_config: Optional[dict] = None,
-        scheduler_config: Optional[dict] = None,
-        loss_config: Optional[LossConfig] = None,
+        optimizer_config: _Optional[dict] = None,
+        scheduler_config: _Optional[dict] = None,
+        loss_config: _Optional[LossConfig] = None,
     ):
         """
         :param scheduler_config: contains
@@ -162,7 +168,7 @@ def __init__(
         # Where to compute the MRSTFT.
         # Keeping it on-device is preferable, but if that fails, then remember to drop
         # it to cpu from then on.
-        self._mrstft_device: Optional[torch.device] = None
+        self._mrstft_device: _Optional[_torch.device] = None

     @classmethod
     def init_from_config(cls, config):
@@ -223,16 +229,16 @@ def register_net_initializer(cls, name, constructor, overwrite: bool = False):
         _model_net_init_registry[name] = constructor

     @property
-    def net(self) -> nn.Module:
+    def net(self) -> _nn.Module:
         return self._net

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
+        optimizer = _torch.optim.Adam(self.parameters(), **self._optimizer_config)
         if self._scheduler_config is None:
             return optimizer
         else:
             lr_scheduler = getattr(
-                torch.optim.lr_scheduler, self._scheduler_config["class"]
+                _torch.optim.lr_scheduler, self._scheduler_config["class"]
             )(optimizer, **self._scheduler_config["kwargs"])
             lr_scheduler_config = {"scheduler": lr_scheduler}
             for key in ("interval", "frequency", "monitor"):
@@ -243,17 +249,17 @@ def configure_optimizers(self):

     def forward(self, *args, **kwargs):
         return self.net(*args, **kwargs)  # TODO deprecate--use self.net() instead.
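[Editor's note] `configure_optimizers` above resolves the scheduler class by name via `getattr(_torch.optim.lr_scheduler, ...)`, so the learning config drives it with plain JSON-style dicts. A sketch consistent with that code path (the ExponentialLR/gamma/lr values are illustrative only, not defaults asserted by this patch):

optimizer_config = {"lr": 0.004}  # forwarded as **kwargs to torch.optim.Adam
scheduler_config = {
    "class": "ExponentialLR",    # looked up on torch.optim.lr_scheduler by name
    "kwargs": {"gamma": 0.995},  # forwarded to the scheduler's constructor
    # Optionally: "interval", "frequency", "monitor" -- copied into the
    # Lightning lr_scheduler_config dict if present, per the loop above.
}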
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+    def on_load_checkpoint(self, checkpoint: _Dict[str, _Any]) -> None:
         # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
         self.net.sample_rate = checkpoint["sample_rate"]

-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+    def on_save_checkpoint(self, checkpoint: _Dict[str, _Any]) -> None:
         # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
         checkpoint["sample_rate"] = self.net.sample_rate

     def _shared_step(
         self, batch
-    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
+    ) -> _Tuple[_torch.Tensor, _torch.Tensor, _Dict[str, _LossItem]]:
         """
         B: Batch size
         L: Sequence length
@@ -267,7 +273,7 @@ def _shared_step(
         loss_dict = {}  # Mind keys versus validation loss requested...
         # Prediction aka MSE loss
         if self._loss_config.fourier:
-            loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
+            loss_dict["MSE_FFT"] = _LossItem(1.0, _mse_fft(preds, targets))
         else:
             loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
         # Pre-emphasized MSE
@@ -300,8 +306,8 @@ def _shared_step(
         if dc_weight is not None and dc_weight > 0.0:
             # Denominator could be a bad idea. I'm going to omit it esp since I'm
             # using mini batches
-            mean_dims = torch.arange(1, preds.ndim).tolist()
-            dc_loss = nn.MSELoss()(
+            mean_dims = _torch.arange(1, preds.ndim).tolist()
+            dc_loss = _nn.MSELoss()(
                 preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
             )
             loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
@@ -344,7 +350,7 @@ def get_val_loss():
             )
         return val_loss

-    def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def _esr_loss(self, preds: _torch.Tensor, targets: _torch.Tensor) -> _torch.Tensor:
         """
         Error signal ratio aka ESR loss.

@@ -358,21 +364,21 @@ def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor
         :param targets: (B,L)
         :return: ()
         """
-        return esr(preds, targets)
+        return _esr(preds, targets)

-    def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
+    def _mse_loss(self, preds, targets, pre_emph_coef: _Optional[float] = None):
         if pre_emph_coef is not None:
             preds, targets = [
-                apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+                _apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
             ]
-        return nn.MSELoss()(preds, targets)
+        return _nn.MSELoss()(preds, targets)

     def _mrstft_loss(
         self,
-        preds: torch.Tensor,
-        targets: torch.Tensor,
-        pre_emph_coef: Optional[float] = None,
-    ) -> torch.Tensor:
+        preds: _torch.Tensor,
+        targets: _torch.Tensor,
+        pre_emph_coef: _Optional[float] = None,
+    ) -> _torch.Tensor:
         """
         Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss
         implementation.
         B: Batch size
@@ -383,16 +389,16 @@ def _mrstft_loss(
         :return: ()
         """
         if self._mrstft is None:
-            self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
+            self._mrstft = _auraloss.freq.MultiResolutionSTFTLoss()

         backup_device = "cpu"

         if pre_emph_coef is not None:
             preds, targets = [
-                apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+                _apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
             ]

         try:
-            return multi_resolution_stft_loss(
+            return _multi_resolution_stft_loss(
                 preds, targets, self._mrstft, device=self._mrstft_device
             )
         except Exception as e:
@@ -400,6 +406,6 @@ def _mrstft_loss(
                 raise e
             logger.warning("MRSTFT failed on device; falling back to CPU")
             self._mrstft_device = backup_device
-            return multi_resolution_stft_loss(
+            return _multi_resolution_stft_loss(
                 preds, targets, self._mrstft, device=self._mrstft_device
             )
diff --git a/nam/train/metadata.py b/nam/train/metadata.py
index d5a1c25..630bd52 100644
--- a/nam/train/metadata.py
+++ b/nam/train/metadata.py
@@ -9,15 +9,15 @@
 # This isn't part of ../metadata because it's not necessarily worth knowning about--only
 # if you're using the simplified trainers!

-from typing import List, Optional
+from typing import List as _List, Optional as _Optional

-from pydantic import BaseModel
+from pydantic import BaseModel as _BaseModel

 # The key under which the metadata are saved in the .nam:
 TRAINING_KEY = "training"


-class Settings(BaseModel):
+class Settings(_BaseModel):
     """
     User-provided settings
     """
@@ -25,7 +25,7 @@ class Settings(BaseModel):
     ignore_checks: bool


-class LatencyCalibrationWarnings(BaseModel):
+class LatencyCalibrationWarnings(_BaseModel):
     """
     Things that aren't necessarily wrong with the latency calibration but are
     worth looking into.
@@ -42,34 +42,34 @@ class LatencyCalibrationWarnings(BaseModel):
     disagreement_too_high: bool


-class LatencyCalibration(BaseModel):
+class LatencyCalibration(_BaseModel):
     algorithm_version: int
-    delays: List[int]
+    delays: _List[int]
     safety_factor: int
     recommended: int
     warnings: LatencyCalibrationWarnings


-class Latency(BaseModel):
+class Latency(_BaseModel):
     """
     Information about the latency
     """

-    manual: Optional[int]
+    manual: _Optional[int]
     calibration: LatencyCalibration


-class DataChecks(BaseModel):
+class DataChecks(_BaseModel):
     version: int
     passed: bool


-class Data(BaseModel):
+class Data(_BaseModel):
     latency: Latency
     checks: DataChecks


-class TrainingMetadata(BaseModel):
+class TrainingMetadata(_BaseModel):
     settings: Settings
     data: Data
-    validation_esr: Optional[float]
+    validation_esr: _Optional[float]
diff --git a/nam/util.py b/nam/util.py
index c9e8197..55b873d 100644
--- a/nam/util.py
+++ b/nam/util.py
@@ -6,12 +6,12 @@
 Helpful utilities
 """

-import warnings
-from datetime import datetime
+import warnings as _warnings
+from datetime import datetime as _datetime


 def timestamp() -> str:
-    t = datetime.now()
+    t = _datetime.now()
     return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"

@@ -28,10 +28,10 @@ def __init__(self, *args, **kwargs):
         self._kwargs = kwargs

     def __enter__(self):
-        warnings.filterwarnings(*self._args, **self._kwargs)
+        _warnings.filterwarnings(*self._args, **self._kwargs)

     def __exit__(self, exc_type, exc_val, exc_tb):
-        warnings.resetwarnings()
+        _warnings.resetwarnings()


 def filter_warnings(*args, **kwargs):
diff --git a/tests/test_nam/test_train/test_lightning_module.py b/tests/test_nam/test_train/test_lightning_module.py
index bc1a3aa..903e94d 100644
--- a/tests/test_nam/test_train/test_lightning_module.py
+++ b/tests/test_nam/test_train/test_lightning_module.py
@@ -48,7 +48,7 @@ def mocked_loss(
             raise RuntimeError("Trigger fallback")
         return _torch.tensor(1.0)

-    mocker.patch("nam.train.lightning_module.multi_resolution_stft_loss", mocked_loss)
+    mocker.patch("nam.train.lightning_module._multi_resolution_stft_loss", mocked_loss)

     batch_size = 3
     sequence_length = 4096
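[Editor's note] The one-line test change above is the behavioral consequence of the aliasing worth calling out: `mocker.patch` (pytest-mock's wrapper around `unittest.mock.patch`) replaces an attribute on the module where the name is *looked up*, not where it is defined. Since `nam.train.lightning_module` now binds the loss under the name `_multi_resolution_stft_loss`, that is the attribute the patch target must name:

# Patch the alias that lightning_module actually calls at runtime:
mocker.patch("nam.train.lightning_module._multi_resolution_stft_loss", mocked_loss)

# Patching the definition site would not affect the module under test, because
# lightning_module already holds its own reference to the original function:
# mocker.patch("nam.models.losses.multi_resolution_stft_loss", mocked_loss)  # wrong target here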