-
Notifications
You must be signed in to change notification settings - Fork 18
/
train.py
115 lines (89 loc) · 3.43 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Trains, evaluates and saves the TensorDetect model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import sys
import collections
# Configure the root logger here, before TensorFlow is imported further down
# (presumably so TensorFlow's own logging setup does not pre-empt
# basicConfig -- see the issue linked below).
# The original code branched on the TV_IS_DEV environment variable, but both
# branches were identical, so a single call is equivalent.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS

# Make the bundled submodules under ``incl/`` (e.g. tensorvision) importable.
# The path is resolved relative to this script instead of the current working
# directory, so train.py can be launched from any directory.
sys.path.insert(1, os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                'incl'))

# These imports depend on the sys.path entry added above.
import tensorvision.train as train  # noqa: E402
import tensorvision.utils as utils  # noqa: E402

flags.DEFINE_string('name', None,
                    'Append a name Tag to run.')
flags.DEFINE_string('project', None,
                    'Append a project Tag to run.')
flags.DEFINE_string('hypes', None,
                    'File storing model parameters.')
flags.DEFINE_string('mod', None,
                    'Modifier for model parameters.')
flags.DEFINE_boolean(
    'save', True, ('Whether to save the run (default: True). With --nosave '
                   'output will be saved to the folder TV_DIR_RUNS/debug, '
                   'hence it will get overwritten by further runs.'))
def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Inspired by :meth:`dict.update`, but instead of only updating top-level
    keys, nested dicts are merged to an arbitrary depth: when both ``dct[k]``
    is a ``dict`` and ``merge_dct[k]`` is a Mapping, the two are merged
    recursively; otherwise ``merge_dct[k]`` replaces (or adds) ``dct[k]``.

    :param dct: dict onto which the merge is executed; modified in place.
    :param merge_dct: mapping whose entries are merged into ``dct``.
    :return: None
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` since 3.3. Fall back for Python 2, which this file
    # still targets (see the ``__future__`` imports).
    try:
        from collections.abc import Mapping
    except ImportError:  # pragma: no cover - Python 2 only
        from collections import Mapping

    # ``.items()`` works on both Python 2 and 3 (``iteritems`` is Py2-only).
    for k, v in merge_dct.items():
        if k in dct and isinstance(dct[k], dict) and isinstance(v, Mapping):
            dict_merge(dct[k], v)
        else:
            dct[k] = v
def main(_):
    """Load the hyperparameters and run training.

    Called by ``tf.app.run``. Reads the JSON hypes file given by ``--hypes``,
    optionally merges the Python-literal dict given by ``--mod`` into it,
    sets up the run directories and hands off to ``tensorvision.train``.
    Exits with status 1 on missing submodules, a missing ``--hypes`` flag,
    or a ``--mod`` value that is not a dict literal.

    :param _: remaining command-line arguments from ``tf.app.run`` (unused).
    """
    utils.set_gpus_to_use()

    # Fail early with a helpful hint if the git submodules are not checked out.
    try:
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute: "
                      "'git submodule update --init --recursive'")
        sys.exit(1)

    hypes_file = FLAGS.hypes
    if hypes_file is None:
        logging.error("No hype file is given.")
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(hypes_file, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    if FLAGS.mod is not None:
        import ast
        # literal_eval only accepts Python literals, so --mod cannot run code.
        mod_dict = ast.literal_eval(FLAGS.mod)
        if not isinstance(mod_dict, dict):
            logging.error("--mod must be a dict literal, got: %r", FLAGS.mod)
            sys.exit(1)
        dict_merge(hypes, mod_dict)

    if 'TV_DIR_RUNS' in os.environ:
        # Place this project's runs in a 'KittiClass' sub-directory.
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiClass')
    utils.set_dirs(hypes, hypes_file)
    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    train.maybe_download_and_extract(hypes)
    logging.info("Start training")
    train.do_training(hypes)
if __name__ == '__main__':
    # tf.app.run parses the command-line flags and then calls main(argv);
    # passing main explicitly is what it would look up on __main__ anyway.
    tf.app.run(main=main)