-
Notifications
You must be signed in to change notification settings - Fork 36
/
inference.py
65 lines (51 loc) · 1.8 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import print_function
import argparse
import os
import sys
import time
from PIL import Image
import tensorflow as tf
import numpy as np
from scipy import misc
from model import FCN8s, PSPNet50, ENet, ICNet
# Default output directory; used as the default value of --save-dir.
save_dir = './output/'

# File handed to model.load() for each value accepted by --model.
model_path = dict(
    pspnet='./model/pspnet50.npy',
    fcn='./model/fcn.npy',
    enet='./model/cityscapes/enet.ckpt',
    icnet='./model/cityscapes/icnet.npy',
)
def get_arguments(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads sys.argv[1:], which matches the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``img_path``, ``save_dir``
        and ``model``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--model`` is not one of the accepted choices.
    """
    parser = argparse.ArgumentParser(description="Reproduced PSPNet")

    # Required options carry no default: argparse would never fall back to it.
    parser.add_argument("--img-path", type=str, required=True,
                        help="Path to the RGB image file.")
    parser.add_argument("--save-dir", type=str, default=save_dir,
                        help="Path to save output.")

    # Build the help text from the choices so the two cannot drift apart
    # (the help used to advertise only "pspnet or fcn").
    model_choices = ['pspnet', 'fcn', 'enet', 'icnet']
    parser.add_argument("--model", type=str, required=True,
                        choices=model_choices,
                        help="Model to run: " + ", ".join(model_choices) + ".")

    return parser.parse_args(argv)
def main():
    """Run the selected model on one image and save the prediction.

    Steps, in order: parse the command line, instantiate the model class
    matching ``--model``, let it read the input image, create a TensorFlow
    session, initialise global variables, load the model file listed in
    ``model_path``, run ``model.forward`` and write ``preds[0]`` to
    ``<save_dir>/<model>_<model.img_name>``.
    """
    args = get_arguments()

    # Dispatch table instead of an if/elif chain; the keys mirror the
    # values argparse accepts for --model.
    model_classes = {
        'pspnet': PSPNet50,
        'fcn': FCN8s,
        'enet': ENet,
        'icnet': ICNet,
    }
    model = model_classes[args.model]()
    model.read_input(args.img_path)

    # Init tf Session. allow_growth makes TensorFlow allocate GPU memory on
    # demand instead of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session (and frees its resources) even
    # if loading, inference or saving raises.
    with tf.Session(config=config) as sess:
        # Initialise first, then load, so the loaded values are not
        # overwritten by the initializer.
        sess.run(tf.global_variables_initializer())
        model.load(model_path[args.model], sess)

        preds = model.forward(sess)

        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        # os.path.join keeps the file inside save_dir even when the user
        # passes a directory without a trailing slash (plain concatenation
        # produced e.g. "outpspnet_x.png" for --save-dir out).
        out_name = args.model + '_' + model.img_name
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; kept here
        # because a drop-in replacement depends on the dtype/range of
        # preds[0], which is not visible from this file - confirm and migrate.
        misc.imsave(os.path.join(args.save_dir, out_name), preds[0])


if __name__ == '__main__':
    main()