This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 19
/
train.py
63 lines (54 loc) · 1.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import argparse
from src.dataset import BinauralDataset
from src.models import BinauralNetwork
from src.trainer import Trainer
def build_parser():
    """Build the command-line interface for the training script.

    Returns:
        argparse.ArgumentParser: parser exposing the training-data path, the
        output directory, the GPU count and the number of WaveNet blocks.
    """
    parser = argparse.ArgumentParser(description="Train a BinauralNetwork model.")
    parser.add_argument("--dataset_directory",
                        type=str,
                        default="./data/trainset",
                        help="path to the training data")
    parser.add_argument("--artifacts_directory",
                        type=str,
                        default="./outputs",
                        help="directory to write model files to")
    parser.add_argument("--num_gpus",
                        type=int,
                        default=4,
                        help="number of GPUs used during training")
    # Forwarded to BinauralNetwork as `wavenet_blocks` (see main()).
    parser.add_argument("--blocks",
                        type=int,
                        default=3,
                        help="number of WaveNet blocks in the network")
    return parser


def build_config(args):
    """Assemble the configuration dict that is handed to ``Trainer``.

    Args:
        args: namespace produced by ``build_parser().parse_args(...)``; only
            ``artifacts_directory`` and ``num_gpus`` are read here.

    Returns:
        dict: run settings and hyper-parameters. A new dict is built on each
        call, so callers may mutate it freely.
    """
    return {
        "artifacts_dir": args.artifacts_directory,
        "learning_rate": 0.001,
        # NOTE(review): the newbob_* keys presumably drive a "newbob" learning-rate
        # decay schedule inside Trainer -- confirm against src.trainer.
        "newbob_decay": 0.5,
        "newbob_max_decay": 0.01,
        "batch_size": 32,
        "mask_beginning": 1024,
        "loss_weights": {"l2": 1.0, "phase": 0.01},
        "save_frequency": 10,
        "epochs": 100,
        "num_gpus": args.num_gpus,
    }


def main(argv=None):
    """Parse arguments, build the dataset and model, and run training.

    Args:
        argv: list of command-line arguments to parse. ``None`` (the default)
            lets argparse fall back to ``sys.argv[1:]``, so running the script
            directly behaves as before.
    """
    args = build_parser().parse_args(argv)
    config = build_config(args)

    # Make sure the output directory exists before the trainer writes to it.
    os.makedirs(config["artifacts_dir"], exist_ok=True)

    dataset = BinauralDataset(dataset_directory=args.dataset_directory, chunk_size_ms=200, overlap=0.5)
    net = BinauralNetwork(view_dim=7,
                          warpnet_layers=4,
                          warpnet_channels=64,
                          wavenet_blocks=args.blocks,
                          layers_per_block=10,
                          wavenet_channels=64,
                          )

    print(f"receptive field: {net.receptive_field()}")
    print(f"train on {len(dataset.chunks)} chunks")
    print(f"number of trainable parameters: {net.num_trainable_parameters()}")

    trainer = Trainer(config, net, dataset)
    trainer.train()


# Only train when executed as a script. Importing this module (e.g. to reuse
# build_parser/build_config) must not parse sys.argv or start a training run.
if __name__ == "__main__":
    main()