# train.py
import argparse

import tensorflow as tf

from model.model import DepthAwareNet, ParameterizedNet, BackboneSharedParameterizedNet, Kitti3DPredictor
from model.loss import L2DepthLoss, L2NormRMSE
from solver.optimizer import OptimizerFactory

tf.keras.backend.clear_session()
parser = argparse.ArgumentParser(description='Select between small or big data',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--data-size', type=str, choices=['big', 'small', 'real', 'kitti'], default='small')
parser.add_argument('-m', '--training-mode', type=str, choices=['normal', 'parameterized', 'shared', 'kitti'], default='parameterized')
parser.add_argument('-b', '--batch-size', type=int, default=32)
parser.add_argument('-j', '--jobs', type=int, default=8)
args = parser.parse_args()
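# Example invocation (flag names and choices as defined above; the data paths below are environment-specific):
#   python train.py --data-size kitti --training-mode kitti --batch-size 32 --jobs 8

# Input resolution as (height, width); the network is built for single-channel images.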
# input_shape = (180, 320)
input_shape = (180, 595)
################################
# Define data and dataloader #
################################
if args.data_size == 'big':
    train_path = "./train_new.csv"
    val_path = "./val_new.csv"
    img_directory = "/media/data/teamAI/phuc/phuc/airsim/data"
elif args.data_size == 'kitti':
    train_path = "/media/data/teamAI/minh/kitti_out/kitti_train.csv"
    val_path = "/media/data/teamAI/minh/kitti_out/kitti_val.csv"
    img_directory = "/media/data/teamAI/minh/kitti_out/semantic-0.4"
else:
    train_path = "./train588_50_new.csv"
    val_path = "./val588_50_new.csv"
    img_directory = "/media/data/teamAI/phuc/phuc/airsim/50imperpose/full/"
if args.training_mode == 'normal':
    from data.parallel_dataset import Dataset, DataLoader
else:
    from data.parameterized_parallel_dataset import Dataset, DataLoader
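# Build train/val datasets from the CSV annotation files and the image directory selected above.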
train_dataset = Dataset(train_path, img_directory, input_shape)
val_dataset = Dataset(val_path, img_directory, input_shape)
train_loader = DataLoader(train_dataset, input_shape=input_shape, batch_size=args.batch_size, num_parallel_calls=args.jobs)
val_loader = DataLoader(val_dataset, input_shape=input_shape, batch_size=args.batch_size, num_parallel_calls=args.jobs)
################
# Define model #
################
if args.training_mode == 'parameterized':
    net = ParameterizedNet(num_ext_conv=1)
elif args.training_mode == 'shared':
    net = BackboneSharedParameterizedNet(num_ext_conv=1)
elif args.data_size == 'kitti':
    net = Kitti3DPredictor(num_ext_conv=1)
else:
    net = DepthAwareNet(num_ext_conv=0)
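# Build the subclassed model with an explicit (batch, H, W, 1) input shape so summary() can report shapes.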
net.build(input_shape=(None, input_shape[0], input_shape[1], 1))
# inputs = tf.keras.Input(shape=(input_shape[0], input_shape[1], 1))
# _ = net.call(inputs)
net.summary()
#######################
# Define loss function #
#######################
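# Toggle between Keras mean squared error and the custom L2 losses from model/loss.py.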
USE_MSE = True
if USE_MSE:
    dist_loss_fn = tf.keras.losses.MeanSquaredError()
    depth_loss_fn = tf.keras.losses.MeanSquaredError()
else:
    dist_loss_fn = L2NormRMSE()
    depth_loss_fn = L2DepthLoss()
####################
# Define optimizer #
####################
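# Fixed learning rate of 1e-3 with no learning-rate schedule.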
factory = OptimizerFactory(lr=1e-3, use_scheduler=False)
optimizer = factory.get_optimizer()
# Select the trainer implementation and run training
if args.training_mode == 'normal':
    from solver.trainer import Trainer
else:
    from solver.parameterized_trainer import Trainer
trainer = Trainer(train_loader, val_loader=val_loader,
                  model=net, distance_loss_fn=dist_loss_fn, depth_loss_fn=depth_loss_fn,
                  optimizer=optimizer,
                  log_path='/media/data/teamAI/minh/ivsr-logs/training_kitti1507.txt',
                  savepath='/media/data/teamAI/minh/ivsr_weights/training_kitti1507',
                  use_mse=USE_MSE)
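# Train for 30 epochs with checkpoint saving and early stopping enabled.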
_ = trainer.train(30, save_checkpoint=True, early_stop=True)
#trainer.save_model()