# train.py — MaChAmp training entry point.
import argparse
import sys
import torch
from machamp.model import trainer
# Command-line interface for the MaChAmp training script.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_configs", nargs='+',
                    help="Path(s) to dataset configurations (use --sequential to train on them sequentially, "
                         "default is joint training).")
parser.add_argument("--name", default="", type=str, help="Log dir name.")
parser.add_argument("--sequential", action="store_true",
                    help="Enables finetuning sequentially, this will train the same weights once for each "
                         "dataset_config you pass.")
parser.add_argument("--parameters_config", default="configs/params.json", type=str,
                    help="Configuration file for parameters of the model.")
parser.add_argument("--device", default=None, type=int, help="CUDA device; set to -1 for CPU.")
# --resume and --model_dir both control where the model/logs live, so they
# are mutually exclusive.
model_dir_group = parser.add_mutually_exclusive_group()
model_dir_group.add_argument("--resume", default='', type=str,
                             help='Finalize training on a model for which training abruptly stopped. Give the path to the log '
                                  'directory of the model.')
model_dir_group.add_argument("--model_dir", default=None, type=str,
                             help='Specify a directory to store model and logs in. Overrides the default.')
parser.add_argument("--retrain", type=str, default='',
                    # Fixed help text: was "Retrain on an previously train MaChAmp model".
                    help="Retrain a previously trained MaChAmp model. Specify the path to model.tar.gz and add a "
                         "dataset_config that specifies the new training.")
parser.add_argument("--seed", type=int, default=8446, help="seed to use for training.")
args = parser.parse_args()

# Training needs at least one dataset configuration unless we are resuming
# an interrupted run (the configs are then recovered from the log dir).
if args.resume == '' and (args.dataset_configs is None or len(args.dataset_configs) == 0):
    print('Please provide at least 1 dataset configuration')
    sys.exit(1)

# Resolve the torch device string: auto-detect by default, -1 forces CPU,
# any other integer selects that CUDA device.
if args.device is None:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
elif args.device == -1:
    device = 'cpu'
else:
    device = 'cuda:' + str(args.device)

# Derive a log-dir name from the dataset config filenames (basename without
# extension, joined with '.') when no explicit --name was given.
name = args.name
if args.resume == '' and name == '':
    names = [cfg[cfg.rfind('/') + 1: cfg.rfind('.') if '.' in cfg else len(cfg)] for cfg in args.dataset_configs]
    name = '.'.join(names)
if args.resume != '':
    # NOTE(review): assumes the resume path has the shape 'logs/<name>/...';
    # a trailing-slash or absolute path would pick the wrong component — confirm.
    name = args.resume.split('/')[1]

cmd = ' '.join(sys.argv)  # full command line, passed along for reproducibility logging
if args.sequential:
    # Sequential finetuning: train on the first config, then retrain the
    # resulting model (prevDir) on each subsequent config in turn.
    # NOTE(review): this first call does not forward args.model_dir, unlike
    # the later calls and the joint branch — confirm whether that is intended.
    prevDir = trainer.train(name + '.0', args.parameters_config, [args.dataset_configs[0]], device,
                            args.resume, args.retrain, args.seed, cmd)
    for datasetIdx, dataset in enumerate(args.dataset_configs[1:]):
        modelName = name + '.' + str(datasetIdx + 1)
        prevDir = trainer.train(modelName, args.parameters_config, [dataset], device, None, prevDir,
                                args.seed, cmd, args.model_dir)
else:
    # Joint training on all dataset configs at once (the default).
    trainer.train(name, args.parameters_config, args.dataset_configs, device, args.resume,
                  args.retrain, args.seed, cmd, args.model_dir)