-
Notifications
You must be signed in to change notification settings - Fork 82
/
trainval.py
101 lines (83 loc) · 3.64 KB
/
trainval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import ast
import os
import torch
import yaml
from src.processor import processor
# Reproducibility setup, applied once at import time:
#   * seed torch's global RNG with a fixed value (0);
#   * stop cuDNN from autotuning (benchmark=False) and ask it to use
#     deterministic algorithms, reducing run-to-run variation on GPU.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def get_parser():
    """Create the command-line parser for STAR training / testing.

    The options are declared below as a table of ``(flag, add_argument
    keyword arguments)`` pairs and registered in table order, so ``--help``
    lists them in exactly that sequence.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    # Flags typed with ast.literal_eval turn command-line text such as
    # "True"/"False" into real Python literals (e.g. bools) instead of
    # always-truthy strings.
    option_table = (
        ('--dataset', {'default': 'eth5'}),
        ('--save_dir', {}),
        ('--model_dir', {}),
        ('--config', {}),
        ('--using_cuda', {'default': True, 'type': ast.literal_eval}),
        ('--test_set', {
            'default': 'eth',
            'type': str,
            'help': 'Set this value to [eth, hotel, zara1, zara2, univ] for ETH-univ, ETH-hotel, UCY-zara01, UCY-zara02, UCY-univ',
        }),
        ('--base_dir', {'default': '.', 'help': 'Base directory including these scripts.'}),
        ('--save_base_dir', {'default': './output/', 'help': 'Directory for saving caches and models.'}),
        ('--phase', {'default': 'train', 'help': 'Set this value to \'train\' or \'test\''}),
        ('--train_model', {'default': 'star', 'help': 'Your model name'}),
        ('--load_model', {'default': None, 'type': str, 'help': "load pretrained model for test or training"}),
        ('--model', {'default': 'star.STAR'}),
        ('--seq_length', {'default': 20, 'type': int}),
        ('--obs_length', {'default': 8, 'type': int}),
        ('--pred_length', {'default': 12, 'type': int}),
        ('--batch_around_ped', {'default': 256, 'type': int}),
        ('--batch_size', {'default': 8, 'type': int}),
        ('--test_batch_size', {'default': 4, 'type': int}),
        ('--show_step', {'default': 100, 'type': int}),
        ('--start_test', {'default': 10, 'type': int}),
        ('--sample_num', {'default': 20, 'type': int}),
        ('--num_epochs', {'default': 300, 'type': int}),
        ('--ifshow_detail', {'default': True, 'type': ast.literal_eval}),
        ('--ifsave_results', {'default': False, 'type': ast.literal_eval}),
        ('--randomRotate', {
            'default': True,
            'type': ast.literal_eval,
            'help': "=True:random rotation of each trajectory fragment",
        }),
        ('--neighbor_thred', {'default': 10, 'type': int}),
        ('--learning_rate', {'default': 0.0015, 'type': float}),
        ('--clip', {'default': 1, 'type': int}),
    )

    cli = argparse.ArgumentParser(description='STAR')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli
def load_arg(p, arg_parser=None):
    """Re-parse the command line using a saved YAML config as defaults.

    If the file at ``p.config`` exists, its key/value pairs are installed as
    parser defaults and the command line is parsed again, so options passed
    explicitly on the command line still take precedence over the file.
    Keys in the file that are not attributes of ``p`` are reported with a
    ``WRONG ARG`` message but are still applied as defaults.

    Args:
        p: Namespace from an earlier parse; must carry a ``config`` path.
            Its attribute names define which config keys are "known".
        arg_parser: Parser to update and re-run. When omitted, the
            module-level ``parser`` created in the ``__main__`` section is
            used, matching the previous behaviour.

    Returns:
        The newly parsed ``argparse.Namespace`` if the config file exists,
        otherwise ``False``.
    """
    if not os.path.exists(p.config):
        return False

    if arg_parser is None:
        # Backward compatibility: fall back to the script-level global.
        arg_parser = parser

    with open(p.config, 'r', encoding='utf-8') as f:
        # safe_load: bare yaml.load() without a Loader raises TypeError on
        # PyYAML >= 6 and, on older versions, can build arbitrary objects.
        # Configs written via yaml.dump(vars(args)) are expected to hold
        # plain YAML values. An empty file yields None -> treat as {}.
        default_arg = yaml.safe_load(f) or {}

    known = vars(p)
    for k in default_arg:
        if k not in known:
            print('WRONG ARG: {}'.format(k))

    arg_parser.set_defaults(**default_arg)
    return arg_parser.parse_args()
def save_arg(args):
    """Dump every attribute of ``args`` as YAML to the file ``args.config``.

    ``args.model_dir`` (including any missing parents) is created first so
    the config file can be written into it.

    Args:
        args: Parsed argparse Namespace; must carry ``model_dir`` and
            ``config`` path attributes.
    """
    arg_dict = vars(args)
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() (e.g. two runs starting at once).
    os.makedirs(args.model_dir, exist_ok=True)
    with open(args.config, 'w', encoding='utf-8') as f:
        yaml.dump(arg_dict, f)
if __name__ == '__main__':
    # `parser` is deliberately a module-level name: load_arg falls back to it.
    parser = get_parser()
    p = parser.parse_args()

    # Output layout: <save_base_dir>/<test_set>/<train_model>/. os.path.join
    # copes with a --save_base_dir given without a trailing slash; the final
    # '' component keeps the trailing separator on the directory paths, which
    # the original string-built paths always had.
    p.save_dir = os.path.join(p.save_base_dir, str(p.test_set), '')
    p.model_dir = os.path.join(p.save_dir, p.train_model, '')
    p.config = os.path.join(p.model_dir, 'config_' + p.phase + '.yaml')

    # Use the saved per-phase config if present; otherwise persist the
    # current arguments first and read them back, so both paths end up with
    # a Namespace produced by load_arg. Loading only once avoids re-reading
    # and re-parsing when the config already exists.
    args = load_arg(p)
    if not args:
        save_arg(p)
        args = load_arg(p)

    # torch.cuda.set_device fails on hosts without CUDA, so only pin device 0
    # when CUDA use is requested.
    # NOTE(review): assumes processor honours args.using_cuda — confirm in
    # src/processor.
    if args.using_cuda:
        torch.cuda.set_device(0)

    trainer = processor(args)
    if args.phase == 'test':
        trainer.test()
    else:
        trainer.train()