This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

added evaluation script
Alexander Richard committed Apr 5, 2021
1 parent 4b031ee commit 8d1b858
Showing 4 changed files with 206 additions and 24 deletions.
142 changes: 142 additions & 0 deletions evaluate.py
@@ -0,0 +1,142 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import argparse
import numpy as np
import torch as th
import torchaudio as ta

from src.models import BinauralNetwork
from src.losses import L2Loss, AmplitudeLoss, PhaseLoss


parser = argparse.ArgumentParser()
parser.add_argument("--dataset_directory",
type=str,
default="./data/testset",
help="path to the test data")
parser.add_argument("--model_file",
type=str,
default="./outputs/binaural_network.net",
help="model file containing the trained binaural network weights")
parser.add_argument("--artifacts_directory",
type=str,
default="./outputs",
help="directory to write binaural outputs to")
args = parser.parse_args()


def chunked_forwarding(net, mono, view):
    '''
    binauralize the mono input given the view
    :param net: binauralization network
    :param mono: 1 x T tensor containing the mono audio signal
    :param view: 7 x K tensor containing the view as 3D positions and quaternions for orientation (K = T / 400)
    :return: 2 x T tensor containing the binauralized audio signal
    '''
    net.eval().cuda()
    mono, view = mono.cuda(), view.cuda()

    chunk_size = 480000  # forward in chunks of 10s
    rec_field = net.receptive_field() + 1000  # add 1000 samples as a "safe bet" since warping has an undefined receptive field
    rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
    chunks = [
        {
            "mono": mono[:, max(0, i-rec_field):i+chunk_size],
            "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
        }
        for i in range(0, mono.shape[-1], chunk_size)
    ]

    for i, chunk in enumerate(chunks):
        with th.no_grad():
            mono = chunk["mono"].unsqueeze(0)
            view = chunk["view"].unsqueeze(0)
            binaural = net(mono, view)["output"].squeeze(0)
            if i > 0:
                binaural = binaural[:, -(mono.shape[-1]-rec_field):]
            chunk["binaural"] = binaural

    binaural = th.cat([chunk["binaural"] for chunk in chunks], dim=-1)
    binaural = th.clamp(binaural, min=-1, max=1).cpu()
    return binaural


def compute_metrics(binauralized, reference):
    '''
    compute the l2 error, amplitude error, and angular phase error for the given binauralized and reference signals
    :param binauralized: 2 x T tensor containing the predicted binaural signal
    :param reference: 2 x T tensor containing the reference binaural signal
    :return: errors as a scalar value for each metric and the number of samples in the sequence
    '''
    binauralized, reference = binauralized.unsqueeze(0), reference.unsqueeze(0)

    # compute error metrics
    l2_error = L2Loss()(binauralized, reference)
    amplitude_error = AmplitudeLoss(sample_rate=48000)(binauralized, reference)
    phase_error = PhaseLoss(sample_rate=48000, ignore_below=0.2)(binauralized, reference)

    return {
        "l2": l2_error,
        "amplitude": amplitude_error,
        "phase": phase_error,
        "samples": binauralized.shape[-1]
    }


# binauralize and evaluate the test sequences for the eight subjects and the validation sequence
test_sequences = [f"subject{i+1}" for i in range(8)] + ["validation_sequence"]

# initialize network
net = BinauralNetwork(view_dim=7,
                      warpnet_layers=4,
                      warpnet_channels=64,
                      wavenet_blocks=3,
                      layers_per_block=10,
                      wavenet_channels=64
                      )
net.load_from_file(args.model_file)

os.makedirs(f"{args.artifacts_directory}", exist_ok=True)

errors = []
for test_sequence in test_sequences:
print(f"binauralize {test_sequence}...")

# load mono input and view conditioning
mono, sr = ta.load(f"{args.dataset_directory}/{test_sequence}/mono.wav")
view = np.loadtxt(f"{args.dataset_directory}/{test_sequence}/tx_positions.txt").transpose().astype(np.float32)
view = th.from_numpy(view)

# sanity checks
if not sr == 48000:
raise Exception(f"sampling rate is expected to be 48000 but is {sr}.")
if not view.shape[-1] * 400 == mono.shape[-1]:
raise Exception(f"mono signal is expected to have 400x the length of the position/orientation sequence.")

# binauralize and save output
binaural = chunked_forwarding(net, mono, view)
ta.save(f"{args.artifacts_directory}/{test_sequence}.wav", binaural, sr)

# compute error metrics
reference, sr = ta.load(f"{args.dataset_directory}/{test_sequence}/binaural.wav")
errors.append(compute_metrics(binaural, reference))

# accumulate errors
sequence_weights = np.array([err["samples"] for err in errors])
sequence_weights = sequence_weights / np.sum(sequence_weights)
l2_error = sum([err["l2"] * sequence_weights[i] for i, err in enumerate(errors)])
amplitude_error = sum([err["amplitude"] * sequence_weights[i] for i, err in enumerate(errors)])
phase_error = sum([err["phase"] * sequence_weights[i] for i, err in enumerate(errors)])

# print accumulated errors on testset
print(f"l2 (x10^3): {l2_error * 1000:.3f}")
print(f"amplitude: {amplitude_error:.3f}")
print(f"phase: {phase_error:.3f}")

31 changes: 28 additions & 3 deletions src/losses.py
@@ -44,7 +44,7 @@ def _loss(self, data, target):
        return th.mean((data - target).pow(2))


class PhaseLoss(Loss):
class AmplitudeLoss(Loss):
    def __init__(self, sample_rate, mask_beginning=0):
        '''
        :param sample_rate: (int) sample rate of the audio signal
@@ -56,6 +56,31 @@ def __init__(self, sample_rate, mask_beginning=0):
    def _transform(self, data):
        return self.fft.stft(data.view(-1, data.shape[-1]))

    def _loss(self, data, target):
        '''
        :param data: predicted wave signals in a B x channels x T tensor
        :param target: target wave signals in a B x channels x T tensor
        :return: a scalar loss value
        '''
        data, target = self._transform(data), self._transform(target)
        data = th.sum(data**2, dim=-1) ** 0.5
        target = th.sum(target**2, dim=-1) ** 0.5
        return th.mean(th.abs(data - target))


class PhaseLoss(Loss):
    def __init__(self, sample_rate, mask_beginning=0, ignore_below=0.1):
        '''
        :param sample_rate: (int) sample rate of the audio signal
        :param mask_beginning: (int) number of samples to mask at the beginning of the signal
        '''
        super().__init__(mask_beginning)
        self.ignore_below = ignore_below
        self.fft = FourierTransform(sample_rate=sample_rate)

    def _transform(self, data):
        return self.fft.stft(data.view(-1, data.shape[-1]))

    def _loss(self, data, target):
        '''
        :param data: predicted wave signals in a B x channels x T tensor
@@ -66,8 +91,8 @@ def _loss(self, data, target):
        # ignore low energy components for numerical stability
        target_energy = th.sum(th.abs(target), dim=-1)
        pred_energy = th.sum(th.abs(data.detach()), dim=-1)
        target_mask = target_energy > 0.1 * th.mean(target_energy)
        pred_mask = pred_energy > 0.1 * th.mean(target_energy)
        target_mask = target_energy > self.ignore_below * th.mean(target_energy)
        pred_mask = pred_energy > self.ignore_below * th.mean(target_energy)
        indices = th.nonzero(target_mask * pred_mask).view(-1)
        data, target = th.index_select(data, 0, indices), th.index_select(target, 0, indices)
        # compute actual phase loss in angular space
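Note: a quick standalone check (not part of the commit) of the magnitude computation the new AmplitudeLoss relies on. It assumes, as the code above does, that the STFT returns real and imaginary parts stacked in the last dimension, so sum(x**2, dim=-1) ** 0.5 is the complex magnitude.

import torch as th

# Toy spectrum with two bins, 3+4j and 0+2j (made-up values, real/imag in last dim)
spec = th.tensor([[3.0, 4.0], [0.0, 2.0]])
mag = th.sum(spec ** 2, dim=-1) ** 0.5
assert th.allclose(mag, th.tensor([5.0, 2.0]))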
30 changes: 16 additions & 14 deletions src/utils.py
@@ -36,30 +36,32 @@ def save(self, model_dir, suffix=''):
        if self.use_cuda:
            self.cuda()

    '''
    load network parameters from model_dir/model_name.suffix.net
    model_dir: (str) directory where the model should be stored
    suffix: (str) optional suffix to append to the network name
    '''
    def load_from_file(self, model_file):
        '''
        load network parameters from model_file
        :param model_file: file containing the model parameters
        '''
        if self.use_cuda:
            self.cpu()

        states = th.load(model_file)
        self.load_state_dict(states)

        if self.use_cuda:
            self.cuda()
        print(f"Loaded: {model_file}")

    def load(self, model_dir, suffix=''):
        '''
        load network parameters from model_dir/model_name.suffix.net
        :param model_dir: directory to load the model from
        :param suffix: suffix to append after model name
        '''
        if self.use_cuda:
            self.cpu()

        if suffix == "":
            fname = f"{model_dir}/{self.model_name}.net"
        else:
            fname = f"{model_dir}/{self.model_name}.{suffix}.net"

        states = th.load(fname)
        self.load_state_dict(states)
        if self.use_cuda:
            self.cuda()
        print("Loaded:", fname)
        self.load_from_file(fname)

    def num_trainable_parameters(self):
        '''
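Note: load_from_file wraps the usual state-dict round trip; the minimal sketch below is not part of the commit (the toy module and path are assumptions) and only shows the mechanism it delegates to.

import torch as th

# Save and restore a state dict, as load_from_file does internally (toy module/path).
net = th.nn.Linear(4, 2)
th.save(net.state_dict(), "/tmp/toy_model.net")

states = th.load("/tmp/toy_model.net")
net.load_state_dict(states)
print("Loaded: /tmp/toy_model.net")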
27 changes: 20 additions & 7 deletions train.py
@@ -7,18 +7,29 @@
"""

import os
import argparse

from src.dataset import BinauralDataset
from src.models import BinauralNetwork
from src.trainer import Trainer

dataset_dir = "/mnt/home/richardalex/tmp/bobatea/data/trainset"
artifacts_dir = "/mnt/home/richardalex/tmp/artifacts"

os.makedirs(artifacts_dir, exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_directory",
type=str,
default="./data/trainset",
help="path to the training data")
parser.add_argument("--artifacts_directory",
type=str,
default="./outputs",
help="directory to write model files to")
parser.add_argument("--num_gpus",
type=int,
default=4,
help="number of GPUs used during training")
args = parser.parse_args()

config = {
"artifacts_dir": artifacts_dir,
"artifacts_dir": args.artifacts_directory,
"learning_rate": 0.001,
"newbob_decay": 0.5,
"newbob_max_decay": 0.01,
Expand All @@ -27,10 +38,12 @@
"loss_weights": {"l2": 1.0, "phase": 0.01},
"save_frequency": 10,
"epochs": 100,
"num_gpus": 4,
"num_gpus": args.num_gpus,
}

dataset = BinauralDataset(dataset_directory=dataset_dir, chunk_size_ms=200, overlap=0.5)
os.makedirs(config["artifacts_dir"], exist_ok=True)

dataset = BinauralDataset(dataset_directory=args.dataset_directory, chunk_size_ms=200, overlap=0.5)

net = BinauralNetwork(view_dim=7,
                      warpnet_layers=4,
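Note: the new flags replace the hardcoded paths; the sketch below is not part of the commit and mirrors the parser above with an explicit argv to show how overrides flow into the config dict.

import argparse

# Mirror of the train.py parser; pass [] for the defaults or a list to override.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_directory", type=str, default="./data/trainset")
parser.add_argument("--artifacts_directory", type=str, default="./outputs")
parser.add_argument("--num_gpus", type=int, default=4)
args = parser.parse_args(["--num_gpus", "1"])

config_fragment = {"artifacts_dir": args.artifacts_directory, "num_gpus": args.num_gpus}
print(config_fragment)  # {'artifacts_dir': './outputs', 'num_gpus': 1}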
