diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..74d492c
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,142 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+All rights reserved.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import argparse
+import numpy as np
+import torch as th
+import torchaudio as ta
+
+from src.models import BinauralNetwork
+from src.losses import L2Loss, AmplitudeLoss, PhaseLoss
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset_directory",
+                    type=str,
+                    default="./data/testset",
+                    help="path to the test data")
+parser.add_argument("--model_file",
+                    type=str,
+                    default="./outputs/binaural_network.net",
+                    help="model file containing the trained binaural network weights")
+parser.add_argument("--artifacts_directory",
+                    type=str,
+                    default="./outputs",
+                    help="directory to write binaural outputs to")
+args = parser.parse_args()
+
+
+def chunked_forwarding(net, mono, view):
+    '''
+    binauralize the mono input given the view
+    :param net: binauralization network
+    :param mono: 1 x T tensor containing the mono audio signal
+    :param view: 7 x K tensor containing the view as 3D positions and quaternions for orientation (K = T / 400)
+    :return: 2 x T tensor containing the binauralized audio signal
+    '''
+    net.eval().cuda()
+    mono, view = mono.cuda(), view.cuda()
+
+    chunk_size = 480000  # forward in chunks of 10s
+    rec_field = net.receptive_field() + 1000  # add 1000 samples as "safe bet" since warping has undefined rec. field
+    rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
+    chunks = [
+        {
+            "mono": mono[:, max(0, i-rec_field):i+chunk_size],
+            "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
+        }
+        for i in range(0, mono.shape[-1], chunk_size)
+    ]
+
+    for i, chunk in enumerate(chunks):
+        with th.no_grad():
+            mono = chunk["mono"].unsqueeze(0)
+            view = chunk["view"].unsqueeze(0)
+            binaural = net(mono, view)["output"].squeeze(0)
+            if i > 0:
+                binaural = binaural[:, -(mono.shape[-1]-rec_field):]
+            chunk["binaural"] = binaural
+
+    binaural = th.cat([chunk["binaural"] for chunk in chunks], dim=-1)
+    binaural = th.clamp(binaural, min=-1, max=1).cpu()
+    return binaural
+
+
+def compute_metrics(binauralized, reference):
+    '''
+    compute l2 error, amplitude error, and angular phase error for the given binaural and reference signals
+    :param binauralized: 2 x T tensor containing the predicted binaural signal
+    :param reference: 2 x T tensor containing the reference binaural signal
+    :return: errors as a scalar value for each metric and the number of samples in the sequence
+    '''
+    binauralized, reference = binauralized.unsqueeze(0), reference.unsqueeze(0)
+
+    # compute error metrics
+    l2_error = L2Loss()(binauralized, reference)
+    amplitude_error = AmplitudeLoss(sample_rate=48000)(binauralized, reference)
+    phase_error = PhaseLoss(sample_rate=48000, ignore_below=0.2)(binauralized, reference)
+
+    return {
+        "l2": l2_error,
+        "amplitude": amplitude_error,
+        "phase": phase_error,
+        "samples": binauralized.shape[-1]
+    }
+
+
+# binauralize and evaluate the test sequences for the eight subjects and the validation sequence
+test_sequences = [f"subject{i+1}" for i in range(8)] + ["validation_sequence"]
+
+# initialize network
+net = BinauralNetwork(view_dim=7,
+                      warpnet_layers=4,
+                      warpnet_channels=64,
+                      wavenet_blocks=3,
+                      layers_per_block=10,
+                      wavenet_channels=64
+                      )
+net.load_from_file(args.model_file)
+
+os.makedirs(f"{args.artifacts_directory}", exist_ok=True)
+
+errors = []
+for test_sequence in test_sequences:
+    print(f"binauralize {test_sequence}...")
+
+    # load mono input and view conditioning
+    mono, sr = ta.load(f"{args.dataset_directory}/{test_sequence}/mono.wav")
+    view = np.loadtxt(f"{args.dataset_directory}/{test_sequence}/tx_positions.txt").transpose().astype(np.float32)
+    view = th.from_numpy(view)
+
+    # sanity checks
+    if not sr == 48000:
+        raise Exception(f"sampling rate is expected to be 48000 but is {sr}.")
+    if not view.shape[-1] * 400 == mono.shape[-1]:
+        raise Exception(f"mono signal is expected to have 400x the length of the position/orientation sequence.")
+
+    # binauralize and save output
+    binaural = chunked_forwarding(net, mono, view)
+    ta.save(f"{args.artifacts_directory}/{test_sequence}.wav", binaural, sr)
+
+    # compute error metrics
+    reference, sr = ta.load(f"{args.dataset_directory}/{test_sequence}/binaural.wav")
+    errors.append(compute_metrics(binaural, reference))
+
+# accumulate errors
+sequence_weights = np.array([err["samples"] for err in errors])
+sequence_weights = sequence_weights / np.sum(sequence_weights)
+l2_error = sum([err["l2"] * sequence_weights[i] for i, err in enumerate(errors)])
+amplitude_error = sum([err["amplitude"] * sequence_weights[i] for i, err in enumerate(errors)])
+phase_error = sum([err["phase"] * sequence_weights[i] for i, err in enumerate(errors)])
+
+# print accumulated errors on testset
+print(f"l2 (x10^3): {l2_error * 1000:.3f}")
+print(f"amplitude: {amplitude_error:.3f}")
+print(f"phase: {phase_error:.3f}")
+
diff --git a/src/losses.py b/src/losses.py
index d2172bb..55532f9 100644
--- a/src/losses.py
+++ b/src/losses.py
@@ -44,7 +44,7 @@
         return th.mean((data - target).pow(2))
 
 
-class PhaseLoss(Loss):
+class AmplitudeLoss(Loss):
     def __init__(self, sample_rate, mask_beginning=0):
         '''
         :param sample_rate: (int) sample rate of the audio signal
@@ -56,6 +56,31 @@ def __init__(self, sample_rate, mask_beginning=0):
     def _transform(self, data):
         return self.fft.stft(data.view(-1, data.shape[-1]))
 
+    def _loss(self, data, target):
+        '''
+        :param data: predicted wave signals in a B x channels x T tensor
+        :param target: target wave signals in a B x channels x T tensor
+        :return: a scalar loss value
+        '''
+        data, target = self._transform(data), self._transform(target)
+        data = th.sum(data**2, dim=-1) ** 0.5
+        target = th.sum(target**2, dim=-1) ** 0.5
+        return th.mean(th.abs(data - target))
+
+
+class PhaseLoss(Loss):
+    def __init__(self, sample_rate, mask_beginning=0, ignore_below=0.1):
+        '''
+        :param sample_rate: (int) sample rate of the audio signal
+        :param mask_beginning: (int) number of samples to mask at the beginning of the signal
+        '''
+        super().__init__(mask_beginning)
+        self.ignore_below = ignore_below
+        self.fft = FourierTransform(sample_rate=sample_rate)
+
+    def _transform(self, data):
+        return self.fft.stft(data.view(-1, data.shape[-1]))
+
     def _loss(self, data, target):
         '''
         :param data: predicted wave signals in a B x channels x T tensor
@@ -66,8 +91,8 @@ def _loss(self, data, target):
         # ignore low energy components for numerical stability
         target_energy = th.sum(th.abs(target), dim=-1)
         pred_energy = th.sum(th.abs(data.detach()), dim=-1)
-        target_mask = target_energy > 0.1 * th.mean(target_energy)
-        pred_mask = pred_energy > 0.1 * th.mean(target_energy)
+        target_mask = target_energy > self.ignore_below * th.mean(target_energy)
+        pred_mask = pred_energy > self.ignore_below * th.mean(target_energy)
         indices = th.nonzero(target_mask * pred_mask).view(-1)
         data, target = th.index_select(data, 0, indices), th.index_select(target, 0, indices)
         # compute actual phase loss in angular space
diff --git a/src/utils.py b/src/utils.py
index 1f41010..b125305 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -36,30 +36,32 @@ def save(self, model_dir, suffix=''):
         if self.use_cuda:
             self.cuda()
 
-    '''
-    load network parameters from model_dir/model_name.suffix.net
-    model_dir: (str) directory where the model should be stored
-    suffix: (str) optional suffix to append to the network name
-    '''
+    def load_from_file(self, model_file):
+        '''
+        load network parameters from model_file
+        :param model_file: file containing the model parameters
+        '''
+        if self.use_cuda:
+            self.cpu()
+
+        states = th.load(model_file)
+        self.load_state_dict(states)
+
+        if self.use_cuda:
+            self.cuda()
+        print(f"Loaded: {model_file}")
+
     def load(self, model_dir, suffix=''):
         '''
         load network parameters from model_dir/model_name.suffix.net
        :param model_dir: directory to load the model from
        :param suffix: suffix to append after model name
         '''
-        if self.use_cuda:
-            self.cpu()
-
         if suffix == "":
             fname = f"{model_dir}/{self.model_name}.net"
         else:
             fname = f"{model_dir}/{self.model_name}.{suffix}.net"
-
-        states = th.load(fname)
-        self.load_state_dict(states)
-
-        if self.use_cuda:
-            self.cuda()
-        print("Loaded:", fname)
+        self.load_from_file(fname)
 
     def num_trainable_parameters(self):
         '''
diff --git a/train.py b/train.py
index 5390353..49f5f24 100644
--- a/train.py
+++ b/train.py
@@ -7,18 +7,29 @@
 """
 
 import os
+import argparse
 
 from src.dataset import BinauralDataset
 from src.models import BinauralNetwork
 from src.trainer import Trainer
 
-dataset_dir = "/mnt/home/richardalex/tmp/bobatea/data/trainset"
-artifacts_dir = "/mnt/home/richardalex/tmp/artifacts"
-
-os.makedirs(artifacts_dir, exist_ok=True)
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset_directory",
+                    type=str,
+                    default="./data/trainset",
+                    help="path to the training data")
+parser.add_argument("--artifacts_directory",
+                    type=str,
+                    default="./outputs",
+                    help="directory to write model files to")
+parser.add_argument("--num_gpus",
+                    type=int,
+                    default=4,
+                    help="number of GPUs used during training")
+args = parser.parse_args()
 
 config = {
-    "artifacts_dir": artifacts_dir,
+    "artifacts_dir": args.artifacts_directory,
     "learning_rate": 0.001,
     "newbob_decay": 0.5,
     "newbob_max_decay": 0.01,
@@ -27,10 +38,12 @@
     "loss_weights": {"l2": 1.0, "phase": 0.01},
     "save_frequency": 10,
     "epochs": 100,
-    "num_gpus": 4,
+    "num_gpus": args.num_gpus,
 }
 
-dataset = BinauralDataset(dataset_directory=dataset_dir, chunk_size_ms=200, overlap=0.5)
+os.makedirs(config["artifacts_dir"], exist_ok=True)
+
+dataset = BinauralDataset(dataset_directory=args.dataset_directory, chunk_size_ms=200, overlap=0.5)
 
 net = BinauralNetwork(view_dim=7,
                       warpnet_layers=4,