From 89ba266e4a875344be4f64f9ad370433ead7c868 Mon Sep 17 00:00:00 2001 From: maks00170 Date: Tue, 10 Oct 2023 00:17:52 +0300 Subject: [PATCH] Add source files --- .dockerignore | 9 + .github/workflows/main.yaml | 18 ++ .gitignore | 8 + Dockerfile | 16 ++ README.md | 74 ++++++- requirements-dev.txt | 3 + requirements.txt | 19 ++ run_docker.sh | 9 + separator/config/config.py | 101 +++++++++ separator/data/dataset.py | 411 ++++++++++++++++++++++++++++++++++ separator/inference.py | 233 ++++++++++++++++++++ separator/model/PM_Unet.py | 253 +++++++++++++++++++++ separator/model/STFT.py | 104 +++++++++ separator/model/modules.py | 426 ++++++++++++++++++++++++++++++++++++ separator/pl_model.py | 355 ++++++++++++++++++++++++++++++ separator/train/augment.py | 361 ++++++++++++++++++++++++++++++ separator/train/loss.py | 171 +++++++++++++++ streaming/config/config.py | 34 +++ streaming/converter.py | 288 ++++++++++++++++++++++++ streaming/runner.py | 148 +++++++++++++ streaming/tf_lite_stream.py | 103 +++++++++ 21 files changed, 3143 insertions(+), 1 deletion(-) create mode 100644 .dockerignore create mode 100644 .github/workflows/main.yaml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 requirements-dev.txt create mode 100644 requirements.txt create mode 100644 run_docker.sh create mode 100644 separator/config/config.py create mode 100644 separator/data/dataset.py create mode 100644 separator/inference.py create mode 100644 separator/model/PM_Unet.py create mode 100644 separator/model/STFT.py create mode 100644 separator/model/modules.py create mode 100644 separator/pl_model.py create mode 100644 separator/train/augment.py create mode 100644 separator/train/loss.py create mode 100644 streaming/config/config.py create mode 100644 streaming/converter.py create mode 100644 streaming/runner.py create mode 100644 streaming/tf_lite_stream.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7332e35 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.github +**/__pycache__/ +separator/inference/ +streaming/weights/ +streaming/input/ +streaming/streams/ +streaming/model/ +streaming/tflite_model/ + diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml new file mode 100644 index 0000000..c874dc1 --- /dev/null +++ b/.github/workflows/main.yaml @@ -0,0 +1,18 @@ +name: Main + +on: [push, pull_request] + +jobs: + main: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: 3.10.12 + cache: "pip" + - name: "installation" + run: | + pip install -r requirements-dev.txt + - name: "black" + run: black . 
--check --diff --color --exclude .*/config/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e510821 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.vscode +**/__pycache__/ +separator/inference/ +streaming/weights +streaming/input/ +streaming/streams/ +streaming/model/ +streaming/tflite_model/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d95aad3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM nvcr.io/nvidia/tensorrt:22.08-py3 + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get -y update && apt-get -y upgrade +RUN apt-get install -y --no-install-recommends ffmpeg +RUN apt-get install -y python3-pip +RUN echo 'alias python=python3' >> ~/.bashrc +RUN echo 'NCCL_SOCKET_IFNAME=lo' >> ~/.bashrc + + +WORKDIR /app +COPY requirements.txt requirements.txt +RUN pip install -r requirements.txt + +ENTRYPOINT [ "bash" ] \ No newline at end of file diff --git a/README.md b/README.md index 156f157..cf1daaa 100644 --- a/README.md +++ b/README.md @@ -1 +1,73 @@ -# PM_AUDIO \ No newline at end of file +# PM-Unet: phase and magnitude aware model for music source separation + [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://d-a-yakovlev.github.io/test/) + [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OXlCZgd5KidMDZDUItOIT9ZA4IUJHXsZ?usp=sharing) + +## Navigation +1. [Structure](#structure) +2. [Docker](#docker) +3. [Training](#training) +4. [Inference](#inference) + +## Structure +- [`separator`](./separator) ‒ main source code with model and dataset implementations and code to train model. +- [`streaming`](./streaming/demo) ‒ source code inference tf-lite version model. + +## Docker +#### To set up environment with Docker + +If you don't have Docker installed, please follow the links to find installation instructions for [Ubuntu](https://docs.docker.com/desktop/install/linux-install/), [Mac](https://docs.docker.com/desktop/install/mac-install/) or [Windows](https://docs.docker.com/desktop/install/windows-install/). + +Build docker image: + + docker build -t pmunet . + +Run docker image: + + bash run_docker.sh + +## Data +Used dataset [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav). + +[![Download dataset](https://img.shields.io/badge/Download%20dataset-65c73b)](https://zenodo.org/record/3338373/files/musdb18hq.zip?download=1) + +The dataset consists of +150 full-length stereo tracks sampled at 44.1 kHz. providing a +complete audio mix and four main elements: ”vocal”, ”bass”, +”drums” and ”other” for each sample, which can be considered as a target in the context of source separation. The kit +structure offers 100 training compositions and 50 validation +compositions + +## Training +1. Configure arguments in `separator/config/config.py`. +2. `cd separator`. +3. Run `python3 separator/pl_model.py`. + +## Inference + +### Auto local +1. Configure arguments in `separator/config/config.py`. +2. `cd separator`. +3. `python3 inference.py [-IO]` + - `-I` specify path to mixture, + - `-O` output dir, both of them optional. + +By default script loads `.pt` file with weights and `sample.wav` from google drive. + +#### For example +``` +python3 inference.py -I path/to/mix -O out_dir +``` +With successful script run four audio files (`vocals.wav` and `drums.wav`, `bass.wav`, `other.wav`) will be in `out_dir`. By default in `separator/inference/output`. 
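+
+You can also call the separator from Python (run from the `separator` directory). A minimal sketch based on `inference.py`; the mixture and output paths below are placeholders:
+
+```python
+import torchaudio
+
+from config.config import InferenceConfig
+from inference import InferenceModel
+
+config = InferenceConfig()
+# Weights are downloaded from Google Drive unless an existing .pt file is
+# passed via weights_path=...
+model = InferenceModel(config, model_bottlneck_lstm=True)
+audios = model.track("path/to/mix.wav", "out_dir")  # keys: drums, bass, other, vocals
+for name, item in audios.items():
+    torchaudio.save(item["path"], item["source"], config.sample_rate)
+```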
+ +**You can download weights manually** + +Download one the .pt file below: + * [LSTM-bottleneck version](https://drive.google.com/file/d/1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo/view?usp=drive_link) + * [WIthout LSTM-bottleneck version](https://drive.google.com/file/d/1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7/view?usp=drive_link) + + ### Streaming + In streaming section located scripts for: convert model to `tflite` format and run `tflite` model in `"stream mode"`. + +1. Configure arguments in `streaming/config/config.py`. +2. `cd streaming`. +3. `python3 runner.py` diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..a8a12f0 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +black +mypy +pytest \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e2afdb6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +ffmpeg-python==0.2.0 +gdown==4.6.3 +julius==0.2.7 +lpips==0.1.4 +musdb==0.4.0 +nobuco +omegaconf==2.3.0 +openunmix==1.2.1 +soundfile==0.12.1 +sox==1.4.1 +stempeg==0.2.3 +sympy==1.12 +tensorflow>=2.13.0 +torch==2.0.1 +torch-audiomentations==0.11.0 +torchaudio==2.0.2 +torchmetrics==0.11.4 +pytorch-lightning==2.0.3 +tqdm==4.65.0 diff --git a/run_docker.sh b/run_docker.sh new file mode 100644 index 0000000..46e60a1 --- /dev/null +++ b/run_docker.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +app=$(pwd) + +docker run --name pmunet -it --rm \ + --net=host --ipc=host \ + --gpus "all" \ + -v ${app}:/app \ + pmunet diff --git a/separator/config/config.py b/separator/config/config.py new file mode 100644 index 0000000..d9e25a1 --- /dev/null +++ b/separator/config/config.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Union + + +from dataclasses import dataclass +from pathlib import Path +from typing import Union + + +@dataclass +class TrainConfig: + + # DATA OPTIONS + musdb_path : str = "musdb18hq" # Directory path where the MUSDB18-HQ dataset is stored. + metadata_train_path : str = "metadata" # Directory path for saving training metadata, like track names and lengths. + metadata_test_path : str = "metadata1" # Directory path for saving testing metadata. + segment : int = 5 # Length (in seconds) of each audio segment used during training. + + # MODEL OPTIONS + model_source : tuple = ("drums", "bass", "other", "vocals") # Sources to target in source separation. + model_depth : int = 4 # The depth of the U-Net architecture. + model_channel : int = 28 # Number of initial channels in U-Net layers. + is_mono : bool = False # Indicates whether the input audio should be treated as mono (True) or stereo (False). + mask_mode : bool = False # Whether to utilize masking within the model. + skip_mode : str = "concat" # Mode of skip connections in U-Net ('concat' for concatenation, 'add' for summation). + nfft : int = 4096 # Number of bins used in STFT. + bottlneck_lstm : bool = True # Determines whether to use LSTM layers as bottleneck in the U-Net architecture. + layers : int = 2 # Number of LSTM layers if bottleneck. + stft_flag : bool = True # A flag to decide whether to apply the STFT is required for tflite. + + # TRAIN OPTIONS + device : str = "cuda" # The computing platform for training: 'cuda' for NVIDIA GPUs or 'cpu'. + batch_size : int = 6 # Batch size for training. + shuffle_train : bool = True # Whether to shuffle the training dataset. + shuffle_valid : bool = False # Whether to shuffle the valid dataset. 
+ drop_last : bool = True # Whether to drop the last incomplete batch in train data. + num_workers : int = 2 # Number of worker processes used for loading data. + metric_monitor_mode : str = "min" # Strategy for monitoring metrics to save model checkpoints. + save_top_k_model_weights : int = 1 # Number of best-performing model weights to save based on the monitored metric. + + factor : int = 1 # Factors for different components of the loss function. + c_factor : int = 1 + + loss_nfft : tuple = (4096,) # Number of FFT bins for calculating loss. + gamma : float = 0.3 # Gamma parameter for adjusting the focus of the loss on certain aspects of the audio spectrum. + lr : float = 0.5 * 3e-3 # Learning rate for the optimizer. + T_0 : int = 40 # Period of the cosine annealing schedule in learning rate adjustment. + max_epochs : int = 100 # Maximum number of training epochs. + precision : str = 16 # Precision of training computations. + grad_clip : float = 0.5 # Gradient clipping value. + + # AUGMENTATION OPTIONS + proba_shift : float = 0.5 # Probability of applying the shift. + shift : int = 8192 # Maximum number of samples for the shift. + proba_flip_channel : float = 1 # Probability of applying the flip left-right channels. + proba_flip_sign : float = 1 # Probability of applying the sign flip. + pitchshift_proba : float = 0.2 # Probability of applying pitch shift. + vocals_min_semitones : int = -5 # The lower limit of vocal semitones. + vocals_max_semitones : int = 5 # The upper limit of vocal semitones. + other_min_semitones : int = -2 # The lower limit of non-vocal semitones. + other_max_semitones : int = 2 # The upper limit of non-vocal semitones. + pitchshift_flag_other : bool = False # Flag to enable pitch shift augmentation on non-vocal sources. + time_change_proba : float = 0.2 # Probability of applying time stretching. + time_change_factors : tuple = (0.8, 0.85, 0.9, 0.95, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3) # Factors for time stretching/compression, defining the range and intensity of this augmentation. + remix_proba : float = 1 # Probability of remixing audio tracks. + remix_group_size : int = batch_size # Size of groups within which shuffling occurs. + scale_proba : float = 1 # Probability of applying the scaling. + scale_min : float = 0.25 # Minimum scaling factor. + scale_max : float = 1.25 # Maximum scaling factor. + fade_mask_proba : float = 0.1 # Probability of applying a fade effect. + double_proba : float = 0.1 # Probability of doubling one channel's audio to both channels. + reverse_proba : float = 0.2 # Probability of reversing a segment of the audio track. + mushap_proba : float = 0.0 # Probability create mashups. + mushap_depth : int = 2 # Number of tracks to mix. + + +@dataclass +class InferenceConfig: + GDRIVE_PREFIX = "https://drive.google.com/uc?id=" # Google Drive URL + + # MODEL OPTIONS + weights_dir : Path = Path("/app/separator/inference/weights") # file name where weights are saved + weights_LSTM_filename : str = "weight_LSTM.pt" # file name model with LSTM + weights_conv_filename : str = "weight_conv.pt" # file name model without LSTM + gdrive_weights_LSTM : str = f"{GDRIVE_PREFIX}1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # Google Drive URL that directs weights LSTM + gdrive_weights_conv : str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # Google Drive URL that directs weights without_LSTM + device : str = "cpu" # The computing platform for inference + + # INFERENCE OPTIONS + segment : int = 7 # Length (in seconds) of each audio segment used during inference. 
+ overlap : float = 0.2 # overlapping segments at the beginning of the track and at the end + offset : Union[int, None] = None # start (in seconds) of segment to split + duration : Union[int, None] = None # duration (in seconds) of segment to split, use with `offset` + sample_rate : int = 44100 # sample rate track + num_channels : int = 2 # Number of channels in the audio track + default_result_dir : str = "/app/separator/inference/output" # path file output tracks + default_input_dir : str = "/app/separator/inference/input" # path file input track + + # TEST TRACK + gdrive_mix : str = f"{GDRIVE_PREFIX}1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # Google Drive URL that directs test track diff --git a/separator/data/dataset.py b/separator/data/dataset.py new file mode 100644 index 0000000..01825fa --- /dev/null +++ b/separator/data/dataset.py @@ -0,0 +1,411 @@ +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +import hashlib +import math +import json +import os +from pathlib import Path +import tqdm +import logging + +import musdb +import julius +import torch as th +import torchaudio as ta +from torch.nn import functional as F + + +@dataclass +class File: + MIXTURE: str = "mixture" + EXT: str = ".wav" + + +def get_musdb_wav_datasets( + musdb="musdb18hq", + segment=11, + shift=1, + train_valid=False, + samplerate=44100, + channels=2, + normalize=True, + metadata="./metadata", + sources=["drums", "bass", "other", "vocals"], + data_type="train", +): + """ + Prepares and retrieves the MUSDB18-HQ dataset for audio source separation. + + This function handles the dataset preparation by creating necessary metadata files and setting up the data loader configuration. It then returns a dataset instance ready for use in training or evaluation. + + Args: + musdb (str): Path to the MUSDB18-HQ dataset directory. + segment (int): Length in seconds of each audio segment for processing. + shift (int): Stride in seconds between consecutive audio segments. + train_valid (bool): Flag to determine if training or validation data should be used. + samplerate (int): Target sample rate for audio processing. + channels (int): Number of audio channels (e.g., 1 for mono, 2 for stereo). + normalize (bool): Whether to normalize audio based on the entire track. + metadata (str): Path for saving the generated metadata. + sources (list[str]): List of source names to be included in the dataset. + data_type (str): Type of data to process ('train' or 'test'). + + Returns: + Wavset: An instance of the Wavset class configured for the specified dataset. + """ + + # Create a unique identifier for the dataset configuration. + sig = hashlib.sha1(str(musdb).encode()).hexdigest()[:8] + + metadata_file = Path(metadata) / f"musdb_{sig}.json" + root = Path(musdb) / data_type + + # Build metadata if not already present. + if not metadata_file.is_file(): + metadata_file.parent.mkdir(exist_ok=True) + metadata_content = MetaData.build_metadata(root, sources) + json.dump(metadata_content, open(metadata_file, "w")) + + # Load metadata from the file. + metadata = json.load(open(metadata_file)) + + # Filter tracks for training or validation based on the configuration. + valid_tracks = _get_musdb_valid() # Retrieve a list of valid track names. + metadata_train = ( + metadata + if train_valid + else {name: meta for name, meta in metadata.items() if name not in valid_tracks} + ) + + # Configure and return the dataset instance. 
+ data_set = Wavset( + root, + metadata_train, + sources, + segment=segment, + shift=shift, + samplerate=samplerate, + channels=channels, + normalize=normalize, + ) + + return data_set + + +def _get_musdb_valid(): + # Return musdb valid set. + import yaml + + setup_path = Path(musdb.__path__[0]) / "configs" / "mus.yaml" + setup = yaml.safe_load(open(setup_path, "r")) + return setup["validation_tracks"] + + +class Wavset: + def __init__( + self, + root, + metadata, + sources, + segment=None, + shift=None, + normalize=True, + samplerate=44100, + channels=2, + ext=File.EXT, + ): + """ + A dataset class for audio source separation, compatible with WAV (or MP3) files. + This class allows training with arbitrary sources, where each audio track is represented by a separate folder within the specified root directory. Each folder should contain audio files for different sources, named as `{source}.{ext}`. + + Args: + root (Path or str): The root directory of the dataset where audio tracks are stored. + metadata (dict): Metadata information generated by the `build_metadata` function. It contains details like track names and lengths. + sources (list[str]): A list of source names to be separated, e.g., ['drums', 'vocals']. + segment (Optional[float]): The length of each audio segment in seconds. If `None`, the entire track is used. + shift (Optional[float]): The stride in seconds between samples. Determines the overlap between consecutive audio segments. + normalize (bool): If True, normalizes the input audio based on the entire track's statistics, not just individual segments. + samplerate (int): The target sample rate. Audio files with a different sample rate will be resampled to this rate. + channels (int): The target number of audio channels. If different, audio will be converted accordingly. + ext (str): The file extension of the audio files, default is '.wav'. + + Note: + The `samplerate` and `channels` parameters are used to ensure consistency across the dataset. They allow on-the-fly conversion of audio properties to match the target specifications. + """ + self.root = Path(root) + self.metadata = OrderedDict(metadata) + self.segment = segment + self.shift = shift or segment + self.normalize = normalize + self.sources = sources + self.channels = channels + self.samplerate = samplerate + self.ext = ext + self.num_examples = [] + for name, meta in self.metadata.items(): + track_duration = meta["length"] / meta["samplerate"] + if segment is None or track_duration < segment: + examples = 1 + else: + examples = int( + math.ceil((track_duration - self.segment) / self.shift) + 1 + ) + self.num_examples.append(examples) + + def __len__(self): + return sum(self.num_examples) + + def get_file(self, name, source): + return self.root / name / f"{source}{self.ext}" + + def __getitem__(self, index): + """ + Get an audio example by index with applied transformations. + + Args: + index (int): The index of the audio example in the dataset. + + Returns: + Tensor: The processed audio example as a tensor. 
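+
+        Note:
+            The flat ``index`` is mapped to a (track, segment) pair by walking
+            through ``self.num_examples``; segment ``i`` of a track starts
+            ``shift * i`` seconds into that track.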
+ """ + # Iterate over each audio source and adjust the index for each source + for name, examples in zip(self.metadata, self.num_examples): + if index >= examples: + index -= examples + continue + + # Access metadata for the current source + meta = self.metadata[name] + + # Calculate offset and number of frames if segmenting is enabled + num_frames, offset = -1, 0 + if self.segment is not None: + offset = int(meta["samplerate"] * self.shift * index) + num_frames = int(math.ceil(meta["samplerate"] * self.segment)) + + # Load and process audio from each source + wavs = [] + for source in self.sources: + file_path = self.get_file(name, source) + wav, _ = ta.load( + str(file_path), frame_offset=offset, num_frames=num_frames + ) + wav = self.__convert_audio_channels(wav, self.channels) + wavs.append(wav) + + # Stack, resample, and normalize the audio examples + example = th.stack(wavs) + example = julius.resample_frac(example, meta["samplerate"], self.samplerate) + if self.normalize: + example = (example - meta["mean"]) / meta["std"] + + # Pad the audio example if segmenting is used + if self.segment: + length = int(self.segment * self.samplerate) + example = example[..., :length] + example = F.pad(example, (0, length - example.shape[-1])) + + return example + + def __convert_audio_channels(self, wav, desired_channels=2): + """ + Convert an audio waveform to the specified number of channels. + + Args: + wav (Tensor): The input waveform tensor with shape (..., channels, length). + desired_channels (int, optional): The number of channels for the output waveform. + Defaults to 2 for stereo output. + + Returns: + Tensor: The waveform with the desired number of channels. + + Raises: + ValueError: If the input audio has fewer channels than requested and is not mono. + + Description: + - If the input already has the desired number of channels, it is returned as is. + - If a mono to stereo conversion is needed, the mono channel is duplicated. + - If downmixing is needed (e.g., from 5.1 to stereo), only the first 'desired_channels' are kept. + - If upmixing is required (e.g., mono to 5.1), the single channel is replicated across all desired channels. + """ + + *shape, src_channels, length = wav.shape + + if src_channels == desired_channels: + # No change needed + return wav + elif src_channels > desired_channels: + # Downmix by slicing to the desired number of channels + return wav[..., :desired_channels, :] + elif src_channels == 1: + # Upmix by replicating the mono channel + return wav.expand(*shape, desired_channels, length) + else: + # Invalid case: input has fewer channels than desired and is not mono + raise ValueError( + "Cannot upmix from fewer than 1 channel unless the source is mono." + ) + + +class MetaData: + @staticmethod + def __track_metadata(track, sources, normalize=True, ext=File.EXT): + """ + Process and return the metadata for a single track. + + Args: + track (Path): Path to the track directory. + sources (list[str]): List of sources to look for. + normalize (bool): If True, calculates normalization values. + ext (str): Extension of audio files. + + Returns: + dict: Dictionary containing the track's metadata. + + Raises: + RuntimeError: If an audio file is invalid. + ValueError: If audio files have inconsistent lengths or sample rates. 
+ """ + track_length, track_samplerate = None, None + mean, std = 0, 1 + + for source in sources + [File.MIXTURE]: + source_file = track / f"{source}{ext}" + if source == File.MIXTURE and not source_file.exists(): + audio, sr = MetaData.__create_mixture(track, sources, ext) + ta.save(source_file, audio, sr, encoding="PCM_F") + + try: + info = ta.info(str(source_file)) + except RuntimeError: + logging.error(f"{source_file} is invalid") + raise + + length, sample_rate = MetaData.__validate_track( + info, track_length, track_samplerate, source_file + ) + if track_length is None: + track_length, track_samplerate = length, sample_rate + + if source == File.MIXTURE and normalize: + mean, std = MetaData.__calculate_normalization(source_file) + + return { + "length": length, + "mean": mean, + "std": std, + "samplerate": track_samplerate, + } + + @staticmethod + def build_metadata(path, sources, normalize=True, ext=File.EXT): + """ + Build and return the metadata for the entire dataset. + + Args: + path (str or Path): Path to the dataset. + sources (list[str]): List of sources to look for. + normalize (bool): If True, calculates normalization values. + ext (str): Extension of audio files. + + Returns: + dict: Dictionary containing metadata for each track in the dataset. + """ + meta = {} + path = Path(path) + pendings = [] + + with ThreadPoolExecutor(8) as pool: + for root, _, _ in os.walk(path, followlinks=True): + root = Path(root) + if root.name.startswith(".") or root == path: + continue + name = str(root.relative_to(path)) + pendings.append( + ( + name, + pool.submit( + MetaData.__track_metadata, root, sources, normalize, ext + ), + ) + ) + + for name, pending in tqdm.tqdm(pendings, ncols=120): + meta[name] = pending.result() + + return meta + + @staticmethod + def __create_mixture(track, sources, ext): + """ + Create and return the audio mixture from individual sources. + + Args: + track (Path): Path to the track directory. + sources (list[str]): List of sources to look for. + ext (str): Extension of audio files. + + Returns: + Tuple[Tensor, int]: The mixture audio tensor and its sample rate. + """ + audio = 0 + for sub_source in sources: + sub_file = track / f"{sub_source}{ext}" + sub_audio, sr = ta.load(sub_file) + audio += sub_audio + + would_clip = audio.abs().max() >= 1 + if would_clip: + assert ta.get_audio_backend() == "soundfile", "use dset.backend=soundfile" + + return audio, sr + + @staticmethod + def __validate_track(info, track_length, track_samplerate, source_file): + """ + Validate the track's length and sample rate. + + Args: + info (AudioMetaData): Metadata of the audio file. + track_length (int): Expected length of the track. + track_samplerate (int): Expected sample rate of the track. + source_file (Path): Path to the source file. + + Returns: + Tuple[int, int]: Length and sample rate of the track. + + Raises: + ValueError: If the track's length or sample rate is inconsistent. + """ + length = info.num_frames + if track_length is not None and track_length != length: + raise ValueError( + f"Invalid length for file {source_file}: " + f"expecting {track_length} but got {length}." + ) + elif track_samplerate is not None and info.sample_rate != track_samplerate: + raise ValueError( + f"Invalid sample rate for file {source_file}: " + f"expecting {track_samplerate} but got {info.sample_rate}." + ) + return length, info.sample_rate + + @staticmethod + def __calculate_normalization(source_file): + """ + Calculate and return the mean and standard deviation for normalization. 
+ + Args: + source_file (Path): Path to the source file. + + Returns: + Tuple[float, float]: Mean and standard deviation of the waveform. + """ + try: + wav, _ = ta.load(str(source_file)) + except RuntimeError: + logging.error(f"{source_file} is invalid") + raise + wav = wav.mean(0) + return wav.mean().item(), wav.std().item() diff --git a/separator/inference.py b/separator/inference.py new file mode 100644 index 0000000..25255c6 --- /dev/null +++ b/separator/inference.py @@ -0,0 +1,233 @@ +import argparse +import gdown +import os +from pathlib import Path + +import torch +import torchaudio +from torchaudio.transforms import Fade + +from model.PM_Unet import Model_Unet + + +class InferenceModel: + def __init__(self, config, model_bottlneck_lstm=True, weights_path=""): + self.config = config + self.model_bottlneck_lstm = model_bottlneck_lstm + + weights_path = "" if weights_path is None else weights_path + if Path(weights_path).is_file(): + self.weights_path = weights_path + else: + self.resolve_weigths() + + self.model = Model_Unet( + source=["drums", "bass", "other", "vocals"], + depth=4, + channel=28, + bottlneck_lstm=model_bottlneck_lstm, + ) + + self.model.load_state_dict( + torch.load(str(self.weights_path), map_location=torch.device("cpu")) + ) + + self.segment = self.config.segment + self.overlap = self.config.overlap + + def resolve_weigths(self): + if self.model_bottlneck_lstm: + self.weights_path = ( + self.config.weights_dir / self.config.weights_LSTM_filename + ) + gdrive_url = self.config.gdrive_weights_LSTM + else: + self.weights_path = ( + self.config.weights_dir / self.config.weights_conv_filename + ) + gdrive_url = self.config.gdrive_weights_conv + + try: + self.config.weights_dir.mkdir(exist_ok=False, parents=True) + download_weights = True + except FileExistsError: + try: + Path(self.weights_path).touch(exist_ok=False) + download_weights = True + except FileExistsError: + download_weights = False + + if download_weights: + gdown.download(gdrive_url, str(self.weights_path)) + + def track(self, sample_mixture_path, output_dir): + if sample_mixture_path == self.config.default_input_dir: + sample_mixture_path = self.resolve_default_sample() + output_path = Path(output_dir) / Path(sample_mixture_path).stem + output_path.mkdir(exist_ok=True, parents=True) + + offset = self.config.offset + duration = self.config.duration + waveform, sr = torchaudio.load(sample_mixture_path) + + start = sr * offset if offset else None + end = sr * (offset + duration) if duration else None + mixture = waveform[:, start:end] + + # Normalize + ref = mixture.mean(0) + mixture = (mixture - ref.mean()) / ref.std() + + # Do separation + sources = self.separate_sources(mixture[None], sample_rate=sr) + + # Denormalize + sources = sources * ref.std() + ref.mean() + + sources_list = ["drums", "bass", "other", "vocals"] + sources_ouputs = {s: f"{str(output_path)}/{s}.wav" for s in sources_list} + + B, S, C, T = sources.shape + sources = ( + sources.view(B, S * C, T) + / sources.view(B, S * C, T).max(dim=2)[0].unsqueeze(-1) + ).view(B, S, C, T) + sources = list(sources) + + audios = dict(zip(sources_list, sources[0])) + for k, v in audios.items(): + audios[k] = {"source": v, "path": sources_ouputs[k]} + + return audios + + def separate_sources(self, mix, sample_rate): + """ + Separates the audio mix into its constituent sources. + + Args: + mix (Tensor): The input mixed audio signal tensor of shape (batch, channels, length). + sample_rate (int): The sample rate of the audio signal. 
+ + Returns: + Tensor: The separated audio sources as a tensor. + """ + # Set the device based on the configuration or input mix + device = torch.device(self.config.device) if self.config.device else mix.device + + # Get the shape of the input mix + batch, channels, length = mix.shape + + # Calculate chunk length for processing and overlap frames + chunk_len = int(sample_rate * self.segment * (1 + self.overlap)) + overlap_frames = int(self.overlap * sample_rate) + fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape="linear") + + # Initialize the tensor to hold the final separated sources + num_sources = 4 # ["drums", "bass", "other", "vocals"] + final = torch.zeros(batch, num_sources, channels, length, device=device) + + start, end = 0, chunk_len + while start < length - overlap_frames: + # Process each chunk with model and apply fade + chunk = mix[:, :, start:end] + with torch.no_grad(): + separated_sources = self.model.forward(chunk) + separated_sources = fade(separated_sources) + final[:, :, :, start:end] += separated_sources + + # Adjust the start and end for the next chunk, and update fade parameters + start, end = self.__update_chunk_indices( + start, end, chunk_len, overlap_frames, length, fade + ) + + return final + + @staticmethod + def __update_chunk_indices(start, end, chunk_len, overlap_frames, length, fade): + """ + Update the chunk indices for the next iteration and adjust fade parameters. + + Args: + start (int): Current start index of the chunk. + end (int): Current end index of the chunk. + chunk_len (int): Length of each chunk. + overlap_frames (int): Number of overlapping frames. + length (int): Total length of the audio signal. + fade (Fade): The Fade object used for applying fade in/out. + + Returns: + Tuple[int, int]: The updated start and end indices for the next chunk. 
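+
+        Note:
+            After the first chunk, the fade-in is enabled for the following
+            chunks; the fade-out is disabled once the chunk reaches the end of
+            the signal so the tail of the track is not attenuated.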
+ """ + if start == 0: + fade.fade_in_len = overlap_frames + start += chunk_len - overlap_frames + else: + start += chunk_len + + end = min(end + chunk_len, length) + fade.fade_out_len = 0 if end >= length else overlap_frames + + return start, end + + def resolve_default_sample(self): + default_input_dir = self.config.default_input_dir + Path(default_input_dir).mkdir(parents=True, exist_ok=True) + + default_sample_path = f"{default_input_dir}/sample.wav" + try: + Path(default_sample_path).touch(exist_ok=False) + gdown.download(self.config.gdrive_mix, default_sample_path) + except FileExistsError: + pass + + return default_sample_path + + +def main(args, config): + inf_model = InferenceModel(config, weights_path=args.weights_path) + audios = inf_model.track(args.mix_path, args.out_dir) + + torchaudio.save( + audios["drums"]["path"], audios["drums"]["source"], config.sample_rate + ) + torchaudio.save( + audios["bass"]["path"], audios["bass"]["source"], config.sample_rate + ) + torchaudio.save( + audios["other"]["path"], audios["other"]["source"], config.sample_rate + ) + torchaudio.save( + audios["vocals"]["path"], audios["vocals"]["source"], config.sample_rate + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Inference script") + from config.config import InferenceConfig + + config = InferenceConfig() + + parser.add_argument( + "-I", + dest="mix_path", + help="path to mixture", + default=config.default_input_dir, + type=str, + ) + parser.add_argument( + "-O", + dest="out_dir", + help="specified output dir", + default=config.default_result_dir, + type=str, + ) + parser.add_argument( + "-w", + dest="weights_path", + help="specified path to weights", + required=False, + type=str, + ) + + args = parser.parse_args() + main(args, config) diff --git a/separator/model/PM_Unet.py b/separator/model/PM_Unet.py new file mode 100644 index 0000000..5a5894f --- /dev/null +++ b/separator/model/PM_Unet.py @@ -0,0 +1,253 @@ +import torch +import torch as th +import torch.nn as nn +from model.STFT import STFT +from model.modules import Encoder, Decoder, Bottleneck_v2, Bottleneck +from typing import List, Optional + + +class Model_Unet(nn.Module): + def __init__( + self, + depth: int = 4, + source: List[str] = ["drums", "bass", "other", "vocals"], + channel: int = 28, + is_mono: Optional[bool] = False, + mask_mode: Optional[bool] = False, + skip_mode: str = "concat", + nfft: int = 4096, + bottlneck_lstm: Optional[bool] = True, + layers: int = 2, + stft_flag: bool = True, + ): + """ + depth (int): Number of layers in both the encoder and decoder. + source (list[str]): List of source names. + channel (int): Initial number of hidden channels. + is_mono (bool): Indicates whether the input/output audio channel is mono. + mask_mode (bool): Enables or disables mask inference. + skip_mode (str): Type of skip connection, either 'concat' or 'add'. + concat: Concatenates the outputs of the encoder and decoder. + add: Adds the outputs of the encoder and decoder. + nfft (int): Number of FFT (Fast Fourier Transform) bins. + bottleneck_lstm (bool): Determines the type of bottleneck to use. + True: Uses a BiLSTM (bidirectional Long Short-Term Memory) bottleneck. + False: Uses a convolutional bottleneck. + layers (int): Number of bottleneck LSTM layers. + stft_flag (bool): Indicates whether to use STFT (Short-Time Fourier Transform). 
+ """ + super().__init__() + self.sources = source + skip_channel = 2 if skip_mode == "concat" else 1 + self.skip_mode = True if skip_mode == "concat" else False + stereo = 1 if is_mono else 2 + self.mask_mode = mask_mode + + norm = self.__norm("InstanceNorm2d") + act = self.__get_act("gelu") + self.stft = STFT(nfft) + self.stft_flag = stft_flag + self.conv_magnitude = nn.Conv2d( + in_channels=stereo, + out_channels=channel, + kernel_size=1, + stride=1, + bias=False, + ) + + self.conv_magnitude_final = nn.Conv2d( + in_channels=channel, + out_channels=len(source) * stereo, + kernel_size=1, + stride=1, + bias=False, + ) + + self.conv_phase = nn.Conv2d( + in_channels=stereo, + out_channels=channel, + kernel_size=1, + stride=1, + bias=False, + ) + + self.conv_phase_final = nn.Conv2d( + in_channels=channel, + out_channels=len(source) * stereo, + kernel_size=1, + stride=1, + bias=False, + ) + + self.encoder_magnitude = nn.ModuleList() + self.decoder_magnitude = nn.ModuleList() + + self.encoder_phase = nn.ModuleList() + self.decoder_phase = nn.ModuleList() + + for idx in range(depth): + self.encoder_magnitude.append( + Encoder( + input_channel=channel, + out_channel=channel * 2, + scale=(4, 1), + stride=(4, 1), + padding=0, + normalization=norm, + activation=act, + ) + ) + + self.encoder_phase.append( + Encoder( + input_channel=channel, + out_channel=channel * 2, + scale=(4, 1), + stride=(4, 1), + padding=0, + normalization=norm, + activation=act, + ) + ) + channel *= 2 + + self.bottlneck_lstm = bottlneck_lstm + if self.bottlneck_lstm: + self.bottleneck_magnitude = Bottleneck_v2( + input_channel=channel * (nfft // 2) // (2 ** (2 * depth)), + out_channel=channel, + normalization=nn.InstanceNorm1d, + activation=act, + layers=layers, + ) + else: + self.bottleneck_magnitude = Bottleneck( + input_channel=channel, + out_channels=channel, + normalization=norm, + activation=act, + ) + + self.bottleneck_phase = Bottleneck(channel, channel, norm, act) + + for idx in range(depth): + self.decoder_magnitude.append( + Decoder( + input_channel=channel * skip_channel, + out_channel=channel // 2, + scale=(4, 1), + stride=(4, 1), + padding=0, + normalization=norm, + activation=act, + ) + ) + self.decoder_phase.append( + Decoder( + input_channel=channel * skip_channel, + out_channel=channel // 2, + scale=(4, 1), + stride=(4, 1), + padding=0, + normalization=norm, + activation=act, + ) + ) + channel //= 2 + + def __wave2feature(self, z: torch.Tensor): + phase = th.atan2(z.imag, z.real) + magnitude = z.abs() + return magnitude, phase + + def __get_act(self, act_type: str): + if act_type == "gelu": + return nn.GELU() + elif act_type == "relu": + return nn.ReLU() + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) + return nn.ELU(alpha) + else: + raise Exception + + def __norm(self, norm_type: str): + if norm_type == "BatchNorm": + return nn.BatchNorm2d + elif norm_type == "InstanceNorm2d": + return nn.InstanceNorm2d + elif norm_type == "InstanceNorm1d": + return nn.InstanceNorm1d + else: + return nn.Identity() + + def __normal(self, x: torch.Tensor): # normalization input signal + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + return mean, std, x + + def forward(self, x: torch.Tensor): + if self.stft_flag: + z = self.stft.stft(x) + length_wave = x.shape[-1] + else: + z = x + x_m, x_p = self.__wave2feature(z) + + B, C, Fq, T = x_m.shape + S = len(self.sources) + + # normalization magnitude input + mean_m, std_m, x_m = 
self.__normal(x_m) + x_mix = x_m + # normalization magnitude phase input + mean_p, std_p, x_p = self.__normal(x_p) + + skip_m = [] # skip connection magnitude branch + skip_p = [] # skip connection phase branch + + x_m = self.conv_magnitude(x_m) # start conv magnitude + x_p = self.conv_phase(x_p) # start conv phase + + for idx_enc in range(len(self.encoder_magnitude)): + x_m = self.encoder_magnitude[idx_enc](x_m) # encoder layer magnitude + x_p = self.encoder_phase[idx_enc](x_p) # encoder layer phase + + skip_m.append(x_m) # skip magnitude + skip_p.append(x_p) # skip phase + + x_m = self.bottleneck_magnitude(x_m) + x_p = self.bottleneck_phase(x_p) + + for idx in range(len(self.decoder_magnitude)): + if self.skip_mode: + x_m = self.decoder_magnitude[idx]( + torch.concat((x_m, skip_m[-idx - 1]), dim=1) + ) # decoder layer magnitude + x_p = self.decoder_phase[idx]( + torch.concat((x_p, skip_p[-idx - 1]), dim=1) + ) # decoder layer phase + else: + x_m = self.decoder_magnitude[idx](x_m + skip_m[-idx - 1]) + x_p = self.decoder_phase[idx](x_p + skip_p[-idx - 1]) + + x_m = self.conv_magnitude_final(x_m) # final conv magnitude + x_p = self.conv_phase_final(x_p) + + x_m = x_m.view(B, S, -1, Fq, T) + x_p = x_p.view(B, S, -1, Fq, T) + + if self.mask_mode: + mask = nn.functional.softmax(x_m.view(B, S, -1, Fq, T), dim=1) + x_m = x_mix.view(B, 1, C, Fq, T) * mask + + x_m = x_m * std_m[:, None] + mean_m[:, None] + x_p = x_p * std_p[:, None] + mean_p[:, None] + + imag = x_m * th.sin(x_p) + real = x_m * th.cos(x_p) + z = th.complex(real, imag) + if self.stft_flag: + return self.stft.istft(z, length_wave) + return z diff --git a/separator/model/STFT.py b/separator/model/STFT.py new file mode 100644 index 0000000..9e9f572 --- /dev/null +++ b/separator/model/STFT.py @@ -0,0 +1,104 @@ +import torch +import torch as th +from typing import Tuple, Optional, Union +import math +import torch.nn.functional as F + + +class STFT: + def __init__(self, n_fft: int = 4096, pad: int = 0): + self.n_fft = n_fft + self.pad = pad + self.hop_length = self.n_fft // 4 + + def __pad1d( + self, + x: torch.Tensor, + paddings: Tuple[int, int], + mode: str = "constant", + value: float = 0.0, + ): + """ + Tiny wrapper around F.pad, designed to allow reflect padding on small inputs. + If the input is too small for reflect padding, we first add extra zero padding to the right before reflection occurs. 
+ """ + x0 = x + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == "reflect": + max_pad = max(padding_left, padding_right) + if length <= max_pad: + extra_pad = max_pad - length + 1 + extra_pad_right = min(padding_right, extra_pad) + extra_pad_left = extra_pad - extra_pad_right + paddings = ( + padding_left - extra_pad_left, + padding_right - extra_pad_right, + ) + x = F.pad(x, (extra_pad_left, extra_pad_right)) + out = F.pad(x, paddings, mode, value) + assert out.shape[-1] == length + padding_left + padding_right + assert (out[..., padding_left : padding_left + length] == x0).all() + return out + + def _spec(self, x: torch.Tensor): + *other, length = x.shape + x = x.reshape(-1, length) + z = th.stft( + x, + self.n_fft * (1 + self.pad), + self.hop_length or self.n_fft // 4, + window=th.hann_window(self.n_fft).to(x), + win_length=self.n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode="reflect", + ) + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + def _ispec(self, z: torch.Tensor, length: int): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + self.pad) + is_mps = z.device.type == "mps" + if is_mps: + z = z.cpu() + z = th.istft( + z, + n_fft, + self.hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True, + ) + _, length = z.shape + return z.view(*other, length) + + def stft(self, x: torch.Tensor): + hl = self.hop_length + x0 = x # noqa + le = int(math.ceil(x.shape[-1] / self.hop_length)) + pad = hl // 2 * 3 + x = self.__pad1d( + x, (pad, pad + le * self.hop_length - x.shape[-1]), mode="reflect" + ) + z = self._spec(x)[..., :-1, :] + z = z[..., 2 : 2 + le] + return z + + def istft(self, z: torch.Tensor, length: int = 0, scale: Optional[int] = 0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + + x = self._ispec(z, length=le) + + x = x[..., pad : pad + length] + return x diff --git a/separator/model/modules.py b/separator/model/modules.py new file mode 100644 index 0000000..063bc10 --- /dev/null +++ b/separator/model/modules.py @@ -0,0 +1,426 @@ +import torch +import torch.nn as nn +import math +from torch.nn import functional as F + + +class DownSample(nn.Module): + """ + DownSample - dimensionality reduction block that includes layer normalization, activation layer, and Conv2d layer. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): Kernel size. + stride (int, tuple): Stride of the convolution. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__( + self, + input_channel, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ): + super().__init__() + + self.conv_layer = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channel, + kernel_size=scale, + stride=stride, + padding=padding, + bias=False, + ), + ) + + def forward(self, x): + return self.conv_layer(x) + + +class UpSample(nn.Module): + """ + UpSample - dimensionality boosting block that includes layer normalization, activation layer, and Conv2d layer. 
+ Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): Kernel size. + stride (int, tuple): Stride of the convolution. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__( + self, + input_channel, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ): + super().__init__() + + self.convT_layer = nn.Sequential( + normalization(input_channel), + activation, + nn.ConvTranspose2d( + in_channels=input_channel, + out_channels=out_channel, + kernel_size=scale, + stride=stride, + padding=padding, + bias=False, + ), + ) + + def forward(self, x): + return self.convT_layer(x) + + +class InceptionBlock(nn.Module): + """ + InceptionBlock: This block comprises three branches, each consisting of normalization layers, activation layers, and 2D convolution layers. The convolution layers in each branch have kernel sizes of 1, 3, and 5, respectively. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__(self, input_channel, out_channel, activation, normalization): + super().__init__() + + self.conv_layer_1 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channel, + kernel_size=1, + stride=1, + bias=False, + ), + ) + + self.conv_layer_2 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channel, + kernel_size=3, + stride=1, + padding="same", + bias=False, + ), + ) + + self.conv_layer_3 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channel, + kernel_size=5, + stride=1, + padding="same", + bias=False, + ), + ) + + def forward(self, x): + x1 = self.conv_layer_1(x) + x2 = self.conv_layer_2(x) + x3 = self.conv_layer_3(x) + return torch.concat((x1, x2, x3), dim=1) + + +class Encoder(nn.Module): + """ + Encoder layer - Block included DownSample layer and InceptionBlock. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): The size of the kernel used in the DownSample layer. + stride (int, tuple): The stride used in the DownSample layer. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__( + self, + input_channel, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ): + super().__init__() + + self.inception_layer = InceptionBlock( + input_channel, out_channel, activation, normalization + ) + self.down_layer = DownSample( + out_channel * 3, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ) + + def forward(self, x): + x = self.inception_layer(x) + x = self.down_layer(x) + return x + + +class Decoder(nn.Module): + """ + Decoder layer - Block included UpSample layer and InceptionBlock. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): The size of the kernel used in the UpSample layer. + stride (int, tuple): The stride used in the UpSample layer. 
+ padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__( + self, + input_channel, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ): + super().__init__() + + self.inception_layer = InceptionBlock( + input_channel, out_channel, activation, normalization + ) + self.up_layer = UpSample( + out_channel * 3, + out_channel, + scale, + stride, + padding, + activation, + normalization, + ) + + def forward(self, x): + x = self.inception_layer(x) + x = self.up_layer(x) + return x + + +class BLSTM(nn.Module): + """ + A bidirectional LSTM (BiLSTM) module with the same number of hidden units as the input dimension. + This module can process inputs in overlapping chunks if `max_steps` is specified. + In this case, the input will be split into chunks, and the LSTM will be applied to each chunk separately. + Args: + dim (int): The number of dimensions in the input and the hidden state of the LSTM. + max_steps (int, optional): The maximum number of steps (length of chunks) for processing the input. Defaults to None. + skip (bool, optional): Flag to enable skip connections. Defaults to False. + layers (int): Number of recurrent layers + """ + + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM( + bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim + ) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def __unfold(self, a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, "data should be contiguous" + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = self.__unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +class Bottleneck_v2(nn.Module): + """ + Bottleneck - bi-lstm bottleneck + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + layers (int): number of recurrent layers + skip (bool): include skip conncetion bi-lstm + stride (int, tuple): The stride used in the Conv1d layer. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. 
+ normalization (object): Normalization layer. + """ + + def __init__( + self, + input_channel, + out_channel, + activation, + normalization, + layers=2, + max_steps=200, + skip=True, + stride=1, + padding="same", + ): + super().__init__() + + self.conv_layer = nn.Sequential( + normalization(input_channel, affine=True), + activation, + nn.Conv1d( + input_channel, + out_channel, + kernel_size=3, + stride=stride, + padding=padding, + ), + ) + + self.biLSTM = BLSTM(out_channel, layers=layers, max_steps=max_steps, skip=skip) + + self.conv_layer_1 = nn.Sequential( + normalization(out_channel, affine=True), + activation, + nn.Conv1d(out_channel, input_channel, kernel_size=1, stride=stride), + ) + + def forward(self, x): + B, C, F, T = x.shape + x = x.view(B, C * F, T) + x = self.conv_layer(x) + x = self.biLSTM(x) + x = self.conv_layer_1(x) + x = x.view(B, C, F, T) + return x + + +class Bottleneck(nn.Module): + """ + Bottleneck - convolution bottleneck + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + + def __init__(self, input_channel, out_channels, normalization, activation): + super().__init__() + + self.conv_layer_1 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + bias=False, + ), + ) + + self.conv_layer_2 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=input_channel, + out_channels=out_channels, + kernel_size=2, + stride=1, + padding="same", + bias=False, + ), + ) + + self.conv_layer_3 = nn.Sequential( + normalization(input_channel), + activation, + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + bias=False, + ), + ) + + def forward(self, x): + x1 = self.conv_layer_1(x) + x2 = self.conv_layer_2(x) + x1 + x3 = self.conv_layer_3(x2) + x1 + + return x3 diff --git a/separator/pl_model.py b/separator/pl_model.py new file mode 100644 index 0000000..cf6c9e9 --- /dev/null +++ b/separator/pl_model.py @@ -0,0 +1,355 @@ +from model.PM_Unet import Model_Unet + +from train.loss import MultiResSpecLoss + +from train import augment + +from pathlib import Path +import torch +import torch.nn as nn +import pytorch_lightning as pl +from torchmetrics import ScaleInvariantSignalDistortionRatio +from torchmetrics.functional.audio import ( + scale_invariant_signal_distortion_ratio, + signal_distortion_ratio, +) + + +class PM_model(pl.LightningModule): + def __init__(self, config): + super().__init__() + + self.model = Model_Unet( + depth=config.model_depth, + source=config.model_source, + channel=config.model_channel, + is_mono=config.is_mono, + mask_mode=config.mask_mode, + skip_mode=config.skip_mode, + nfft=config.nfft, + bottlneck_lstm=config.bottlneck_lstm, + layers=config.layers, + stft_flag=config.stft_flag, + ) + + # loss + # Loss = (L_1 + L_{MRS} - L_{SISDR}) + self.criterion_1 = nn.L1Loss() + self.criterion_2 = MultiResSpecLoss( + factor=config.factor, + f_complex=config.c_factor, + gamma=config.gamma, + n_ffts=config.loss_nfft, + ) + self.criterion_3 = ScaleInvariantSignalDistortionRatio() + + # augment + self.augment = [ + augment.Shift(proba=config.proba_shift, shift=config.shift, same=True) + ] + self.augment += [ + augment.PitchShift( + proba=config.pitchshift_proba, + min_semitones=config.vocals_min_semitones, + 
max_semitones=config.vocals_max_semitones, + min_semitones_other=config.other_min_semitones, + max_semitones_other=config.other_max_semitones, + flag_other=config.pitchshift_flag_other, + ), + augment.TimeChange( + factors_list=config.time_change_factors, proba=config.time_change_proba + ), + augment.FlipChannels(proba=config.proba_flip_channel), + augment.FlipSign(proba=config.proba_flip_sign), + augment.Remix(proba=config.remix_proba, group_size=config.remix_group_size), + augment.Scale( + proba=config.scale_proba, min=config.scale_min, max=config.scale_max + ), + augment.FadeMask(proba=config.fade_mask_proba), + augment.Double(proba=config.double_proba), + augment.Reverse(proba=config.reverse_proba), + augment.RemixWave( + proba=config.mushap_proba, group_size=config.mushap_depth + ), + ] + self.augment = torch.nn.Sequential(*self.augment) + + self.model.apply(self.__init_weights) + + def __init_weights(self, m): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + torch.nn.init.xavier_uniform(m.weight) + + def __usdr(self, predT, tgtT, delta=1e-7): + """ + latex: $usdr=10\log_{10} (\dfrac{\| tgtT\|^2 + \delta}{ \| predT - tgtT\| ^{2} + \delta})$ + """ + num = torch.sum(torch.square(tgtT), dim=(1, 2)) + den = torch.sum(torch.square(tgtT - predT), dim=(1, 2)) + num += delta + den += delta + usdr = 10 * torch.log10(num / den) + return usdr.mean() + + def forward(self, x): + x = self.model(x) + return x + + def loss(self, y_true, y_pred): + # losses are averaged + loss = ( + self.criterion_1(y_pred, y_true) + + self.criterion_2(y_pred, y_true) + - self.criterion_3(y_pred, y_true) + ) / 3 + return loss + + def training_step(self, batch, batch_idx): + source = batch + source = self.augment(source) + mix = source.sum(dim=1) + + source_predict = self.model(mix) + + drums_pred, drums_target = source_predict[:, 0], source[:, 0] + bass_pred, bass_target = source_predict[:, 1], source[:, 1] + other_pred, other_target = source_predict[:, 2], source[:, 2] + vocals_pred, vocals_target = source_predict[:, 3], source[:, 3] + + drums_loss = self.loss(drums_pred, drums_target) + + bass_loss = self.loss(bass_pred, bass_target) + + other_loss = self.loss(other_pred, other_target) + + vocals_loss = self.loss(vocals_pred, vocals_target) + + loss = 0.25 * ( + drums_loss + bass_loss + other_loss + vocals_loss + ) # losses averaged across sources + + self.log_dict( + { + "train_loss": loss, + "train_drums": drums_loss, + "train_bass": bass_loss, + "train_other": other_loss, + "train_vocals": vocals_loss, + }, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + self.log_dict( + { + "train_drums_sdr": signal_distortion_ratio( + drums_pred, drums_target + ).mean(), + "train_bass_sdr": signal_distortion_ratio( + bass_pred, bass_target + ).mean(), + "train_other_sdr": signal_distortion_ratio( + other_pred, other_target + ).mean(), + "train_vocals_sdr": signal_distortion_ratio( + vocals_pred, vocals_target + ).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + + self.log_dict( + { + "train_drums_sisdr": scale_invariant_signal_distortion_ratio( + drums_pred, drums_target + ).mean(), + "train_bass_sisdr": scale_invariant_signal_distortion_ratio( + bass_pred, bass_target + ).mean(), + "train_other_sisdr": scale_invariant_signal_distortion_ratio( + other_pred, other_target + ).mean(), + "train_vocals_sisdr": scale_invariant_signal_distortion_ratio( + vocals_pred, vocals_target + ).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + + self.log_dict( + { + 
"train_drums_usdr": self.__usdr(drums_pred, drums_target).mean(), + "train_bass_usdr": self.__usdr(bass_pred, bass_target).mean(), + "train_other_usdr": self.__usdr(other_pred, other_target).mean(), + "train_vocals_usdr": self.__usdr(vocals_pred, vocals_target).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + return loss + + def validation_step(self, batch, batch_idx): + source = batch + + mix = source.sum(dim=1) + + source_predict = self.model(mix) + drums_pred, drums_target = source_predict[:, 0], source[:, 0] + bass_pred, bass_target = source_predict[:, 1], source[:, 1] + other_pred, other_target = source_predict[:, 2], source[:, 2] + vocals_pred, vocals_target = source_predict[:, 3], source[:, 3] + + drums_loss = self.loss(drums_pred, drums_target) + + bass_loss = self.loss(bass_pred, bass_target) + + other_loss = self.loss(other_pred, other_target) + + vocals_loss = self.loss(vocals_pred, vocals_target) + + loss = 0.25 * (drums_loss + bass_loss + other_loss + vocals_loss) + + self.log_dict( + { + "valid_loss": loss, + "valid_drums": drums_loss, + "valid_bass": bass_loss, + "valid_other": other_loss, + "valid_vocals": vocals_loss, + }, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + self.log_dict( + { + "valid_drums_sdr": signal_distortion_ratio( + drums_pred, drums_target + ).mean(), + "valid_bass_sdr": signal_distortion_ratio( + bass_pred, bass_target + ).mean(), + "valid_other_sdr": signal_distortion_ratio( + other_pred, other_target + ).mean(), + "valid_vocals_sdr": signal_distortion_ratio( + vocals_pred, vocals_target + ).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + + self.log_dict( + { + "valid_drums_sisdr": scale_invariant_signal_distortion_ratio( + drums_pred, drums_target + ).mean(), + "valid_bass_sisdr": scale_invariant_signal_distortion_ratio( + bass_pred, bass_target + ).mean(), + "valid_other_sisdr": scale_invariant_signal_distortion_ratio( + other_pred, other_target + ).mean(), + "valid_vocals_sisdr": scale_invariant_signal_distortion_ratio( + vocals_pred, vocals_target + ).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + + self.log_dict( + { + "valid_drums_usdr": self.__usdr(drums_pred, drums_target).mean(), + "valid_bass_usdr": self.__usdr(bass_pred, bass_target).mean(), + "valid_other_usdr": self.__usdr(other_pred, other_target).mean(), + "valid_vocals_usdr": self.__usdr(vocals_pred, vocals_target).mean(), + }, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + + def configure_optimizers(self): + optimizer = torch.optim.RAdam(self.parameters(), lr=config.lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=config.T_0 + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "valid_loss", + } + + +def main(config): + from data.dataset import get_musdb_wav_datasets + from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor + + Path(config.musdb_path).mkdir(exist_ok=True, parents=True) + Path(config.metadata_train_path).mkdir(exist_ok=True, parents=True) + train_set = get_musdb_wav_datasets( + musdb=config.musdb_path, + data_type="train", + metadata=config.metadata_train_path, + segment=config.segment, + ) + + Path(config.metadata_test_path).mkdir(exist_ok=True, parents=True) + test_set = get_musdb_wav_datasets( + musdb=config.musdb_path, + data_type="test", + metadata=config.metadata_test_path, + segment=config.segment, + ) + + train_dl = torch.utils.data.DataLoader( + train_set, + 
batch_size=config.batch_size, + shuffle=config.shuffle_train, + drop_last=config.drop_last, + num_workers=config.num_workers, + ) + valid_dl = torch.utils.data.DataLoader( + test_set, + batch_size=config.batch_size, + shuffle=config.shuffle_valid, + num_workers=config.num_workers, + ) + + checkpoint_callback = ModelCheckpoint( + monitor="valid_loss", + mode=config.metric_monitor_mode, + save_top_k=config.save_top_k_model_weights, + ) + lr_monitor = LearningRateMonitor(logging_interval="step") + + mp_model = PM_model(config) + + trainer = pl.Trainer( + accelerator="gpu" if config.device == "cuda" else "cpu", + devices="auto", + max_epochs=config.max_epochs, + callbacks=[checkpoint_callback, lr_monitor], + precision=config.precision, + gradient_clip_val=config.grad_clip, + ) + + trainer.fit(mp_model, train_dl, valid_dl) + + +if __name__ == "__main__": + from config.config import TrainConfig + + config = TrainConfig() + main(config) diff --git a/separator/train/augment.py b/separator/train/augment.py new file mode 100644 index 0000000..88ad06d --- /dev/null +++ b/separator/train/augment.py @@ -0,0 +1,361 @@ +import random +import torchaudio +import torch as th +from torch import nn +from torch_audiomentations import PitchShift as ps + + +class Shift(nn.Module): + """ + Shifts audio in time for data augmentation during training. Applies a random shift up to 'shift' samples. + If 'same' is True, all sources in a batch are shifted by the same amount; otherwise, each is shifted differently. + + Args: + proba (float): Probability of applying the shift. + shift (int): Maximum number of samples for the shift. Defaults to 8192. + same (bool): Apply the same shift to all sources in a batch. Defaults to False. + """ + + def __init__(self, proba=1, shift=8192, same=False): + super().__init__() + self.shift = shift + self.same = same + self.proba = proba + + def forward(self, wav): + if self.shift < 1: + return wav + + batch, sources, channels, time = wav.size() + length = time - self.shift + + if random.random() < self.proba: + srcs = 1 if self.same else sources + offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) + offsets = offsets.expand(-1, sources, channels, -1) + indexes = th.arange(length, device=wav.device) + wav = wav.gather(3, indexes + offsets) + return wav + + +class FlipChannels(nn.Module): + """ + Flip left-right channels. + Args: + proba (float): Probability of applying the flip left-right channels. + """ + + def __init__(self, proba=1): + super().__init__() + self.proba = proba + + def forward(self, wav): + batch, sources, channels, time = wav.size() + if wav.size(2) == 2: + if random.random() < self.proba: + left = th.randint(2, (batch, sources, 1, 1), device=wav.device) + left = left.expand(-1, -1, -1, time) + right = 1 - left + wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) + return wav + + +class FlipSign(nn.Module): + """ + Random sign flip. + Args: + proba (float): Probability of applying the sign flip. + """ + + def __init__(self, proba=1): + super().__init__() + + self.proba = proba + + def forward(self, wav): + batch, sources, channels, time = wav.size() + if random.random() < self.proba: + signs = th.randint( + 2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32 + ) + wav = wav * (2 * signs - 1) + return wav + + +class Remix(nn.Module): + """ + Randomly shuffles sources within each batch during training to create new mixes. Shuffling is performed within groups. + Args: + proba (float): Probability of applying the shuffle. 
+ group_size (int): Size of groups within which shuffling occurs. + """ + + def __init__(self, proba=1, group_size=4): + super().__init__() + self.proba = proba + self.group_size = group_size + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + + if self.training and random.random() < self.proba: + group_size = self.group_size or batch + if batch % group_size != 0: + raise ValueError( + f"Batch size {batch} must be divisible by group size {group_size}" + ) + groups = batch // group_size + wav = wav.view(groups, group_size, streams, channels, time) + permutations = th.argsort( + th.rand(groups, group_size, streams, 1, 1, device=device), dim=1 + ) + wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time)) + wav = wav.view(batch, streams, channels, time) + return wav + + +class Scale(nn.Module): + """ + Scales the amplitude of the audio waveform during training. The scaling factor is chosen randomly within a specified range. + Args: + proba (float): Probability of applying the scaling. + min (float): Minimum scaling factor. + max (float): Maximum scaling factor. + """ + + def __init__(self, proba=1.0, min=0.25, max=1.25): + super().__init__() + self.proba = proba + self.min = min + self.max = max + + def forward(self, wav): + batch, streams, channels, time = wav.size() + device = wav.device + if self.training and random.random() < self.proba: + scales = th.empty(batch, streams, 1, 1, device=device).uniform_( + self.min, self.max + ) + wav *= scales + return wav + + +class FadeMask(nn.Module): + """ + Applies time-domain masking to the spectrogram for data augmentation. + Args: + proba (float): Probability of applying the mask. + sample_rate (int): Sample rate of the audio. + time_mask_param (int): Maximum possible length in seconds of the mask. + """ + + def __init__(self, proba=1, sample_rate=44100, time_mask_param=2): + super().__init__() + self.sample_rate = sample_rate + self.time_mask = torchaudio.transforms.TimeMasking( + time_mask_param=sample_rate * time_mask_param + ) + self.proba = proba + + def forward(self, wav): + if random.random() < self.proba: + wav = wav.clone() + wav[:, 0] = self.time_mask(wav[:, 0]) + wav[:, 1] = self.time_mask(wav[:, 1]) + wav[:, 2] = self.time_mask(wav[:, 2]) + wav[:, 3] = self.time_mask(wav[:, 3]) + + return wav # output -> tensor + + +class PitchShift(nn.Module): # input -> tensor + """ + Applies pitch shifting to audio sources. The pitch is shifted up or down without changing the tempo. + Args: + proba (float): Probability of applying the pitch shift. + min_semitones (int): Min shift for vocal source. + max_semitones (int): Max shift for vocal source. + min_semitones_other (int): Min shift for other sources. + max_semitones_other (int): Max shift for other sources. + sample_rate (int): Sample rate of audio. + flag_other (bool): Apply augmentation to other sources. 
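+        Sources are assumed ordered as (drums, bass, other, vocals): the vocal range is applied to index 3, the *_other range to indices 0-2 when flag_other is enabled.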
+ """ + + def __init__( + self, + proba=1, + min_semitones=-5, + max_semitones=5, + min_semitones_other=-2, + max_semitones_other=2, + sample_rate=44100, + flag_other=False, + ): + super().__init__() + self.pitch_vocals = ps( + p=proba, + min_transpose_semitones=min_semitones, + max_transpose_semitones=max_semitones, + sample_rate=sample_rate, + ) + + self.flag_other = flag_other + if flag_other: + self.pitch_other = ps( + p=proba, + min_transpose_semitones=min_semitones_other, + max_transpose_semitones=max_semitones_other, + sample_rate=sample_rate, + ) + + def forward(self, wav): + wav = wav.clone() + if self.flag_other: + wav[:, 0] = self.pitch_other(wav[:, 0]) + wav[:, 1] = self.pitch_other(wav[:, 1]) + wav[:, 2] = self.pitch_other(wav[:, 2]) + wav[:, 3] = self.pitch_vocals(wav[:, 3]) + + return wav + + +class TimeChange(nn.Module): + """ + Changes the speed or duration of the signal without affecting the pitch. + Args: + factors_list (list): List of factors to adjust speed. + proba (float): Probability of applying the time change. + sample_rate (int): Sample rate of audio. + """ + + def __init__(self, factors_list, proba=1, sample_rate=44100): + super().__init__() + self.sample_rate = sample_rate + self.proba = proba + + self.time = torchaudio.transforms.SpeedPerturbation( + orig_freq=sample_rate, factors=factors_list + ) + + def forward(self, wav): + if random.random() < self.proba: + wav, _ = self.time(wav) + + return wav + + +class Double(nn.Module): + """ + With equal probability, makes both channels the same as either the left or right original channel. + Args: + proba (float): Probability of applying the doubling. + """ + + def __init__(self, proba=1): + super().__init__() + self.proba = proba + + def forward(self, wav): + if random.random() < self.proba: + wav = wav.clone() + + if random.random() < 0.5: + wav[:, 0][:, 1] = wav[:, 0][:, 0] + wav[:, 1][:, 1] = wav[:, 1][:, 0] + wav[:, 2][:, 1] = wav[:, 2][:, 0] + wav[:, 3][:, 1] = wav[:, 3][:, 0] + else: + wav[:, 0][:, 0] = wav[:, 0][:, 1] + wav[:, 1][:, 0] = wav[:, 1][:, 1] + wav[:, 2][:, 0] = wav[:, 2][:, 1] + wav[:, 3][:, 0] = wav[:, 3][:, 1] + + return wav + + +class Reverse(nn.Module): + """ + Reverses a segment of the vocal source along the time axis. + Args: + proba (float): Probability of applying the reversal. + min_band_part (float): Minimum fraction of the track to be inverted. + max_band_part (float): Maximum fraction of the track to be inverted.""" + + def __init__(self, proba=1, min_band_part=0.2, max_band_part=0.4): + super().__init__() + self.proba = proba + self.min_band_part = min_band_part + self.max_band_part = max_band_part + + def forward(self, wav): + num_samples = wav.shape[-1] + + if random.random() < self.proba: + wav = wav.clone() + + end = random.randint( + int(num_samples * self.min_band_part), + int(num_samples * self.max_band_part), + ) + start = random.randint(0, num_samples - end) + wav[..., start : end + start][:, 3] = th.flip( + wav[..., start : end + start][:, 3], [2] + ) + + return wav + + +class RemixWave(nn.Module): + """ + Creates a mashup track within a batch. + Args: + proba (float): Probability of applying the mashup. + group_size (int): Group size for mashup. + mix_depth (int): Number of tracks to mix. 
+ """ + + def __init__(self, proba=1, group_size=4, mix_depth=2): + super().__init__() + self.proba = proba + self.remix = Remix(proba=1, group_size=group_size) + self.mix_depth = mix_depth + + def forward(self, wav): + if random.random() < self.proba: + mix = wav.clone() + for i in range(self.mix_depth): + mix += self.remix(wav) + return mix + else: + return wav + + +class RemixChannel(nn.Module): + """ + Shuffles source channels within a batch. + Args: + proba (float): Probability of applying the channel shuffle. + """ + + def __init__(self, proba=1): + super().__init__() + + self.proba = proba + + def forward(self, wav): + batch, streams, channels, time = wav.size() + if self.training and random.random() < self.proba: + drums = wav[:, 0].reshape(-1, time) + bass = wav[:, 1].reshape(-1, time) + other = wav[:, 2].reshape(-1, time) + vocals = wav[:, 3].reshape(-1, time) + + s0 = drums[th.randperm(drums.size()[0])].view(batch, 1, 2, time) + s1 = bass[th.randperm(bass.size()[0])].view(batch, 1, 2, time) + s2 = other[th.randperm(other.size()[0])].view(batch, 1, 2, time) + s3 = vocals[th.randperm(vocals.size()[0])].view(batch, 1, 2, time) + + return th.concat((s0, s1, s2, s3), dim=1) + else: + return wav diff --git a/separator/train/loss.py b/separator/train/loss.py new file mode 100644 index 0000000..91f8ee2 --- /dev/null +++ b/separator/train/loss.py @@ -0,0 +1,171 @@ +import warnings +from typing import Dict, Final, Iterable, List, Optional, Union + +import torch +import torch as th +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +import lpips + + +class angle(Function): + """Similar to torch.angle but robustify the gradient for zero magnitude.""" + + @staticmethod + def forward(ctx, x: Tensor): + ctx.save_for_backward(x) + return torch.atan2(x.imag, x.real) + + @staticmethod + def backward(ctx, grad: Tensor): + (x,) = ctx.saved_tensors + grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10) + return torch.view_as_complex( + torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1) + ) + + +class Stft(nn.Module): + def __init__( + self, n_fft: int, hop: Optional[int] = None, window: Optional[Tensor] = None + ): + super().__init__() + self.n_fft = n_fft + self.hop = hop or n_fft // 4 + if window is not None: + assert window.shape[0] == n_fft + else: + window = torch.hann_window(self.n_fft) + self.w: torch.Tensor + self.register_buffer("w", window) + + def forward(self, input: Tensor): + # Time-domain input shape: [B, *, T] + t = input.shape[-1] + sh = input.shape[:-1] + out = torch.stft( + input.reshape(-1, t), + n_fft=self.n_fft, + hop_length=self.hop, + window=self.w, + normalized=True, + return_complex=True, + ) + out = out.view(*sh, *out.shape[-2:]) + return out + + +class SpectralLoss(nn.Module): + """ + Calculates the L1 loss between the target and predicted magnitudes, and between the target and predicted phases. + The total loss is the sum of L1 loss for magnitude and L1 loss for phase: + L1(target magnitude, predicted magnitude) + L1(target phase, predicted phase). 
+ """ + + def __init__(self, n_fft=4096): + super().__init__() + self.stft = Stft(n_fft) + + def magnitude_phase(self, x): + spectr = self.stft(x) + magnitude = spectr.abs() + phase = th.atan2(spectr.imag, spectr.real) + return magnitude, phase + + def forward(self, target, predict): + target_magnitude, target_phase = self.magnitude_phase(target) + predict_magnitude, predict_phase = self.magnitude_phase(predict) + + loss = F.l1_loss(target_magnitude, predict_magnitude) + F.l1_loss( + target_phase, predict_phase + ) + + return loss + + +class MultiResSpecLoss(nn.Module): + """ + Determines the discrepancies between the anticipated and actual spectrogram based on Short-Time Fourier Transform (STFT) + with varying windows, utilizing the Mean Square Error (MSE) loss function for calculation. + We use loss from Deep-FilterNet https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L95 + """ + + gamma: Final[float] + f: Final[float] + f_complex: Final[Optional[List[float]]] + + def __init__( + self, + n_ffts: Iterable[int], + gamma: float = 1, + factor: float = 1, + f_complex: Optional[Union[float, Iterable[float]]] = None, + ): + super().__init__() + self.gamma = gamma + self.f = factor + self.stfts = nn.ModuleDict({str(n_fft): Stft(n_fft) for n_fft in n_ffts}) + if f_complex is None or f_complex == 0: + self.f_complex = None + elif isinstance(f_complex, Iterable): + self.f_complex = list(f_complex) + else: + self.f_complex = [f_complex] * len(self.stfts) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + loss = torch.zeros((), device=input.device, dtype=input.dtype) + for i, stft in enumerate(self.stfts.values()): + Y = stft(input) + S = stft(target) + Y_abs = Y.abs() + S_abs = S.abs() + if self.gamma != 1: + Y_abs = Y_abs.clamp_min(1e-12).pow(self.gamma) + S_abs = S_abs.clamp_min(1e-12).pow(self.gamma) + loss += F.mse_loss(Y_abs, S_abs) * self.f # mse_loss + if self.f_complex is not None: + if self.gamma != 1: + Y = Y_abs * torch.exp(1j * angle.apply(Y)) + S = S_abs * torch.exp(1j * angle.apply(S)) + loss += ( + F.mse_loss(torch.view_as_real(Y), torch.view_as_real(S)) + * self.f_complex[i] + ) # mse_loss + return loss + + +class PerceptualLoss(nn.Module): + """ + States for perceptual loss and is used to compare high-level features in spectrogram. 
+ Lpips loss and compute L2-distance between trained VGG + https://github.com/richzhang/PerceptualSimilarity + """ + + def __init__(self, net="vgg"): + super().__init__() + self.loss = lpips.LPIPS(net=net) + for param in self.loss.parameters(): + param.requires_grad = False + + def compute_loss(self, source_target, source_predict): + add_channel = torch.randint(0, 2, (1,)) + + source_target_3channel = torch.concat( + [source_target, source_target[:, add_channel]], dim=1 + ) + source_predict_3channel = torch.concat( + [source_predict, source_predict[:, add_channel]], dim=1 + ) + + return self.loss(source_target_3channel, source_predict_3channel) + + def forward(self, target, predict): + loss_drums = self.compute_loss(target[:, 0], predict[:, 0]) + loss_bass = self.compute_loss(target[:, 1], predict[:, 1]) + loss_other = self.compute_loss(target[:, 2], predict[:, 2]) + loss_vocals = self.compute_loss(target[:, 3], predict[:, 3]) + loss = ( + loss_drums.sum() + loss_bass.sum() + loss_other.sum() + loss_vocals.sum() + ) / 4 + return loss diff --git a/streaming/config/config.py b/streaming/config/config.py new file mode 100644 index 0000000..6deead4 --- /dev/null +++ b/streaming/config/config.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class ConverterConfig: + # WEIGHTS LOAD + weights_dir : Path = Path("/app/streaming/weights") # Path to the directory where the model weight files are stored. + weights_LSTM_filename : str = "weight_LSTM.pt" # This is the filename for the LSTM weights file. + weights_conv_filename : str = "weight_conv.pt" # This is the filename for the without CNN weights file. + gdrive_weights_LSTM_id : str = "1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # This is the Google Drive ID for the LSTM weights file. + gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file. + + # MODEL OPTIONS + original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. + original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. + model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined + model_class_name : str = "Model_Unet" # The name of the model class. + tflite_model_dst : str = "/app/streaming/tflite_model" # This is the destination directory for the TFLite model. + sample_rate : int = 44100 # Sample rate track + segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. + + +@dataclass +class StreamConfig: + # STREAM OPTIONS + converter_script : str = "/app/streaming/converter.py" # Path to the script used to convert the pytorch model to tflite. + sample_rate : int = 44100 # Sample rate track. + nfft : int = 4096 # Number of bins used in STFT. + stft_py_module : str = "model.STFT" # Path to the script STFT. + default_input_dir : str = "/app/streaming/input" # Path to the directory where the input files are stored. + default_result_dir : str = "/app/streaming/streams" # Path directory in which processing results are saved. + gdrive_mix_id : str = "1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # The Google Drive ID for the mix file. + default_duration : int = 15 # Length of an audio stream, in seconds. 
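
The two dataclasses above are consumed directly by `converter.py` and `runner.py` (added below). A minimal sketch of how their values fit together, assuming the same `config.config` import path those scripts use:

    from config.config import ConverterConfig, StreamConfig

    conv_cfg = ConverterConfig()
    stream_cfg = StreamConfig()

    # One processed chunk spans sample_rate * segment_duration samples
    # (44100 samples for the default 1.0 s segment).
    chunk_samples = int(conv_cfg.sample_rate * conv_cfg.segment_duration)

    # converter.py picks weight_LSTM.pt or weight_conv.pt depending on whether the
    # model is built with an LSTM bottleneck, and writes the TFLite model to
    # tflite_model_dst with an "_outer_stft_<segment>" suffix that runner.py parses.
    weights = conv_cfg.weights_dir / conv_cfg.weights_LSTM_filename
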
diff --git a/streaming/converter.py b/streaming/converter.py new file mode 100644 index 0000000..8b5de6c --- /dev/null +++ b/streaming/converter.py @@ -0,0 +1,288 @@ +import argparse +import gdown +import importlib +import numbers +import shutil +from typing import Tuple +from pathlib import Path + +import torch +import tensorflow as tf +from tensorflow.lite.python.lite import TFLiteConverter + +import nobuco +from nobuco import ChannelOrder, ChannelOrderingStrategy +from nobuco.layers.weight import WeightLayer + + +@nobuco.converter( + torch.nn.functional.glu, + channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS, +) +def torch_glu(input, dim=1): + def tf_glu(input, dim=1): + ax = -1 if dim == 1 else dim + out, gate = tf.split(input, 2, axis=ax) + gate = tf.sigmoid(gate) + return tf.multiply(out, gate) + + return lambda input, dim=1: tf_glu(input, dim) + + +@nobuco.converter(torch.Tensor.all) +def tensor_all(input: torch.Tensor): + return lambda input: tf.math.reduce_all(input) + + +@nobuco.converter(torch.atan2) +def atan2(input_x: torch.Tensor, input_y: torch.Tensor): + return lambda input_x, input_y: tf.math.atan2(input_x, input_y) + + +@nobuco.converter(torch.Tensor.std) +def tensor_std(input: torch.Tensor, dim, keepdim): + return lambda input, dim, keepdim: tf.math.reduce_std( + input, axis=dim, keepdims=keepdim + ) + + +@nobuco.converter(torch.Tensor.t) +def tensor_t(input: torch.Tensor): + return lambda input: tf.transpose(input) + + +@nobuco.converter(torch.Tensor.__getattribute__) +def torch_getattribute_complex_resolve(input: torch.Tensor, attr: str): + def tf_complex_getattribute(input, attr): + if attr == "real": + tf_func = tf.math.real + elif attr == "imag": + tf_func = tf.math.imag + else: + tf_func = lambda x: getattr(x, attr) + return tf_func(input) + + return tf_complex_getattribute + + +@nobuco.converter(torch.nn.Conv2d) +def converter_Conv2d(self, input: torch.Tensor): + weight = self.weight + bias = self.bias + groups = self.groups + padding = self.padding + stride = self.stride + dilation = self.dilation + + out_filters, in_filters, kh, kw = weight.shape + + weights = weight.cpu().detach().numpy() + weights = tf.transpose(weights, (2, 3, 1, 0)) + + if bias is not None: + biases = bias.cpu().detach().numpy() + params = [weights, biases] + use_bias = True + else: + params = [weights] + use_bias = False + + if padding != 0 and padding != (0, 0) and padding != "valid" and padding != "same": + pad_layer = tf.keras.layers.ZeroPadding2D(padding) + else: + pad_layer = None + + pad_arg = padding if padding == "same" else "valid" + conv = tf.keras.layers.Conv2D( + filters=out_filters, + kernel_size=(kh, kw), + strides=stride, + padding=pad_arg, + dilation_rate=dilation, + groups=groups, + use_bias=use_bias, + weights=params, + ) + + def func(input): + if pad_layer is not None: + input = pad_layer(input) + output = conv(input) + return output + + return func + + +@nobuco.converter(torch.nn.Conv1d) +def converter_Conv1d(self, input: torch.Tensor): + weight = self.weight + bias = self.bias + groups = self.groups + padding = self.padding + stride = self.stride + dilation = self.dilation + + out_filters, in_filters, kw = weight.shape + weights = weight.cpu().detach().numpy() + weights = tf.transpose(weights, (2, 1, 0)) + + if bias is not None: + biases = bias.cpu().detach().numpy() + params = [weights, biases] + use_bias = True + else: + params = [weights] + use_bias = False + + if isinstance(padding, numbers.Number): + padding = (padding,) + if padding != (0,) 
and padding != "valid" and padding != "same": + pad_layer = tf.keras.layers.ZeroPadding1D(padding[0]) + else: + pad_layer = None + + pad_arg = padding if padding == "same" else "valid" + conv = tf.keras.layers.Conv1D( + filters=out_filters, + kernel_size=kw, + strides=stride, + padding=pad_arg, + dilation_rate=dilation, + groups=groups, + use_bias=use_bias, + weights=params, + ) + + def func(input): + if pad_layer is not None: + input = pad_layer(input) + output = conv(input) + return output + + return func + + +@nobuco.converter( + torch.concat, + channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS, +) +def converter_concat(tensors: Tuple[torch.Tensor], dim): + def tf_concat(tensors, dim): + return tf.concat(list(tensors), -dim) + + return tf_concat + + +def main(args, config): + shutil.copytree( + config.original_model_src, config.original_model_dst, dirs_exist_ok=True + ) + py_module = importlib.import_module(args.model_py_module) + cls_model = getattr(py_module, args.class_name) + model = cls_model( + source=["drums", "bass", "other", "vocals"], + depth=4, + channel=28, + bottlneck_lstm=False, + stft_flag=False, + ) + + if model.bottlneck_lstm: + weights_path = config.weights_dir / config.weights_LSTM_filename + gdrive_id = config.gdrive_weights_LSTM_id + else: + weights_path = config.weights_dir / config.weights_conv_filename + gdrive_id = config.gdrive_weights_conv_id + try: + config.weights_dir.mkdir(parents=True, exist_ok=False) + download_weights = True + except FileExistsError: + try: + weights_path.touch(exist_ok=False) + download_weights = True + except FileExistsError: + download_weights = False + if download_weights: + gdown.download(id=gdrive_id, output=str(weights_path)) + + model.load_state_dict( + torch.load(str(weights_path), map_location=torch.device("cpu")) + ) + + model = model.eval() + + class OuterSTFT: + def __init__(self, length_wave, model): + self.length_wave = length_wave + self.model = model + + def stft(self, wave): + return self.model.stft.stft(wave) + + def istft(self, z): + return self.model.stft.istft(z, self.length_wave) + + SEGMENT_WAVE = int(config.sample_rate * config.segment_duration) + dummy_wave = torch.rand(size=(1, 2, SEGMENT_WAVE)) + dummy_spectr = OuterSTFT(SEGMENT_WAVE, model).stft(dummy_wave) + + keras_model = nobuco.pytorch_to_keras( + model, + args=[dummy_spectr], + kwargs=None, + inputs_channel_order=ChannelOrder.PYTORCH, + ) + + model_filename = f"{args.class_name}_outer_stft_{config.segment_duration:.1f}" + model_path = f"{args.out_dir}/{model_filename}" + try: + Path(args.out_dir).mkdir(exist_ok=False) + except (OSError, FileExistsError): + pass + + keras_model.save(f"{model_path}.h5") + custom_objects = {"WeightLayer": WeightLayer} + + converter = TFLiteConverter.from_keras_model_file( + f"{model_path}.h5", custom_objects=custom_objects + ) + converter.target_ops = [ + tf.lite.OpsSet.SELECT_TF_OPS, + tf.lite.OpsSet.TFLITE_BUILTINS, + ] + tflite_model = converter.convert() + + with open(f"{model_path}.tflite", "wb") as f: + f.write(tflite_model) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converter script") + from config.config import ConverterConfig + + config = ConverterConfig() + + parser.add_argument( + "-I", + dest="model_py_module", + help="py module of model\nformat: pkg.mod e.g model.PM_Unet", + default=config.model_py_module, + type=str, + ) + parser.add_argument( + "-C", + dest="class_name", + help="class name of nn.Module", + default=config.model_class_name, + type=str, 
+ ) + parser.add_argument( + "-O", + dest="out_dir", + help="specified output dir", + default=config.tflite_model_dst, + type=str, + ) + + args = parser.parse_args() + main(args, config) diff --git a/streaming/runner.py b/streaming/runner.py new file mode 100644 index 0000000..e6c91da --- /dev/null +++ b/streaming/runner.py @@ -0,0 +1,148 @@ +import argparse +import gdown +import logging +import os +import re +import subprocess as sb +import sys +from pathlib import Path + +from tf_lite_stream import TFLiteTorchStream + + +LOGGER = logging.getLogger(__name__) + + +def resolve_default_sample(config): + default_input_dir = config.StreamConfig.default_input_dir + Path(default_input_dir).mkdir(parents=True, exist_ok=True) + + default_sample_path = f"{default_input_dir}/sample.wav" + try: + Path(default_sample_path).touch(exist_ok=False) + gdown.download(id=config.StreamConfig.gdrive_mix_id, output=default_sample_path) + except FileExistsError: + pass + + return default_sample_path + + +def resolve_tflite_model(config): + try: + Path(config.ConverterConfig.tflite_model_dst).mkdir(exist_ok=False) + start_converter = True + except (OSError, FileExistsError): + if len(os.listdir(config.ConverterConfig.tflite_model_dst)) == 0: + start_converter = True + else: + start_converter = False + + if start_converter: + with sb.Popen( + ["python3", config.StreamConfig.converter_script], + stdout=sb.PIPE, + stderr=sb.STDOUT, + ) as proc: + LOGGER.info(proc.stdout.read().decode()) + res = proc.wait() + LOGGER.info( + f"{config.StreamConfig.converter_script} finished with code : {res}" + ) + + converter_outputs = os.listdir(config.ConverterConfig.tflite_model_dst) + converter_outputs = list( + filter(lambda x: re.match(r".*_outer_stft_.*\.tflite$", x), converter_outputs) + ) + converter_outputs = [ + f"{config.ConverterConfig.tflite_model_dst}/{filename}" + for filename in converter_outputs + ] + converter_outputs.sort(key=lambda x: os.stat(x).st_mtime, reverse=True) + tflite_model_path = converter_outputs[0] + parsed_segment = re.findall(r"_outer_stft_(.*)\.tflite$", tflite_model_path)[0] + + return tflite_model_path, parsed_segment + + +def main(args, config): + is_tflite_model_path_default = ( + args.tflite_model_path == config.ConverterConfig.tflite_model_dst + ) + if not is_tflite_model_path_default and not args.tflite_model_segment: + raise ValueError( + "Specify segment [-s (0.5, 1, ...)] of STFT to outer tflite model" + ) + + if is_tflite_model_path_default: + tflite_model_path, parsed_segment = resolve_tflite_model(config) + else: + tflite_model_path, parsed_segment = ( + args.tflite_model_path, + args.tflite_model_segment, + ) + + track_path = args.mix_path + if args.mix_path == config.StreamConfig.default_input_dir: + track_path = resolve_default_sample(config) + + stream_class = TFLiteTorchStream( + config, tflite_model_path, segment=float(parsed_segment) + ) + out_paths = stream_class(track_path, args.out_dir, args.duration) + LOGGER.info("Streams stored in : " + " ".join(out_paths)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Runner script") + from config import config + + parser.add_argument( + "-I", + dest="mix_path", + help="path to mixture", + default=config.StreamConfig.default_input_dir, + type=str, + ) + parser.add_argument( + "-O", + dest="out_dir", + help="specified output dir", + default=config.StreamConfig.default_result_dir, + type=str, + ) + parser.add_argument( + "-d", + dest="duration", + help="specified first seconds to process", + 
default=config.StreamConfig.default_duration, + type=int, + ) + parser.add_argument( + "-m", + dest="tflite_model_path", + help="path to tflite model", + default=config.ConverterConfig.tflite_model_dst, + type=str, + ) + parser.add_argument( + "-s", + dest="tflite_model_segment", + help="tflite model STFT window width (sample_rate * segment)", + required=False, + type=float, + ) + + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter( + logging.Formatter( + "[%(levelname)s](%(filename)s).%(funcName)s(%(lineno)d) - %(message)s" + ) + ) + logging.basicConfig( + level=logging.DEBUG, + handlers=[stdout_handler], + format="%(levelname)s : %(message)s", + ) + + args = parser.parse_args() + main(args, config) diff --git a/streaming/tf_lite_stream.py b/streaming/tf_lite_stream.py new file mode 100644 index 0000000..73a826b --- /dev/null +++ b/streaming/tf_lite_stream.py @@ -0,0 +1,103 @@ +import importlib +import os +from pathlib import Path +import shutil +import tensorflow as tf +import torch +import torchaudio +from torchaudio.io import StreamReader, StreamWriter +from tqdm import tqdm +from typing import Union + + +class TFLiteTorchStream: + NUM_CHANNELS = 2 + + def __init__(self, config, model_filename: str, segment: float = 1): + self.__interpreter = tf.lite.Interpreter(model_path=model_filename) + self.__interpreter.allocate_tensors() + self.__input_details = self.__interpreter.get_input_details() + self.__output_details = self.__interpreter.get_output_details() + + self.nfft = config.StreamConfig.nfft + self.hop_length = self.nfft // 4 + self.sample_rate = config.StreamConfig.sample_rate + self.segment = segment + + try: + Path(config.ConverterConfig.original_model_dst).mkdir(exist_ok=False) + except FileExistsError: + shutil.rmtree(config.ConverterConfig.original_model_dst) + shutil.copytree( + config.ConverterConfig.original_model_src, + config.ConverterConfig.original_model_dst, + ) + py_module = importlib.import_module(config.StreamConfig.stft_py_module) + cls_stft = getattr(py_module, "STFT") + self.stft = cls_stft(self.nfft) + + def __call__( + self, track_path: str, out_dir: str, duration: Union[int, None] = None + ): + _, sample_rate = torchaudio.load(track_path) + if sample_rate != self.sample_rate: + raise ValueError(f"Non supported sample_rate of {track_path=}") + + stream_mix = StreamReader(src=track_path) + frames_per_chunk = int(44100 * self.segment) + stream_mix.add_basic_audio_stream( + frames_per_chunk=frames_per_chunk, sample_rate=44100 + ) + + try: + os.mkdir(out_dir) + except OSError as error: + pass + + out_paths = ( + f"{out_dir}/drums.wav", + f"{out_dir}/bass.wav", + f"{out_dir}/other.wav", + f"{out_dir}/vocals.wav", + ) + + stream_drums = StreamWriter(dst=out_paths[0]) + stream_bass = StreamWriter(dst=out_paths[1]) + stream_other = StreamWriter(dst=out_paths[2]) + stream_vocals = StreamWriter(dst=out_paths[3]) + + stream_drums.add_audio_stream(sample_rate, TFLiteTorchStream.NUM_CHANNELS) + stream_bass.add_audio_stream(sample_rate, TFLiteTorchStream.NUM_CHANNELS) + stream_other.add_audio_stream(sample_rate, TFLiteTorchStream.NUM_CHANNELS) + stream_vocals.add_audio_stream(sample_rate, TFLiteTorchStream.NUM_CHANNELS) + + chunk_count = int(sample_rate * duration // frames_per_chunk) if duration else 0 + with stream_drums.open(), stream_bass.open(), stream_other.open(), stream_vocals.open(): + for i, chunk in tqdm(enumerate(stream_mix.stream())): + if duration and i > chunk_count: + break + processed_chunk = (chunk[0].T)[None] + if 
processed_chunk.shape[-1] != int(44100 * self.segment):
+                    # skip the trailing chunk if it is shorter than a full segment
+                    continue
+                with torch.no_grad():
+                    # permute to (batch, source, time, channels) so each out[0][i]
+                    # is a (frames, channels) chunk for StreamWriter
+                    out = self.model_call(processed_chunk).permute(0, 1, 3, 2)
+
+                stream_drums.write_audio_chunk(0, out[0][0])
+                stream_bass.write_audio_chunk(0, out[0][1])
+                stream_other.write_audio_chunk(0, out[0][2])
+                stream_vocals.write_audio_chunk(0, out[0][3])
+
+        return out_paths
+
+    def model_call(self, mix):
+        # STFT runs outside the TFLite graph: the converted model consumes and
+        # produces complex spectrograms (see the "_outer_stft_" naming in converter.py)
+        length = mix.shape[-1]
+        in_spectr = self.stft.stft(mix)
+
+        self.__interpreter.set_tensor(self.__input_details[0]["index"], in_spectr)
+        self.__interpreter.invoke()
+
+        out_spectr_tf = self.__interpreter.get_tensor(self.__output_details[0]["index"])
+        out_spectr = torch.tensor(out_spectr_tf, dtype=torch.cfloat)
+
+        ret = self.stft.istft(out_spectr, length)
+        return ret
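
For reference, `runner.py` above drives this class roughly as follows. The model path shown is the name `converter.py` produces for the default 1.0 s segment and the other paths are the defaults from `streaming/config/config.py`; this is an illustrative sketch of the existing flow:

    from config import config
    from tf_lite_stream import TFLiteTorchStream

    # segment must match the "_outer_stft_<segment>.tflite" suffix of the converted model
    stream = TFLiteTorchStream(
        config,
        "/app/streaming/tflite_model/Model_Unet_outer_stft_1.0.tflite",
        segment=1.0,
    )

    # writes drums.wav, bass.wav, other.wav and vocals.wav into the output directory
    out_paths = stream("/app/streaming/input/sample.wav", "/app/streaming/streams", 15)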