-
Notifications
You must be signed in to change notification settings - Fork 6
/
demo.py
69 lines (62 loc) · 3.24 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import mxnet as mx
import os
import sys
from cam import Cam
from cam import Cam_resp
def parse_args(argv=None):
    """Build the command-line parser for the CAM demo and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing a list makes the parser
            usable from tests or other scripts.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Class activation mapping demo')
    parser.add_argument('--network', dest='network', type=str, default='densenet121',
                        help='which network to use')
    parser.add_argument('--images', dest='images', type=str, default='./data/demo/00000377_004.png',
                        help='run demo with images, use comma to seperate multiple images')
    parser.add_argument('--model-path', dest='model_path', type=str,
                        default=os.path.join(os.getcwd(), 'model'),
                        help='trained model path')
    parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect',
                        action='store_true', default=False)
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0,
                        help='GPU device id to detect with')
    parser.add_argument('--data-shape', dest='data_shape', type=int, default=224,
                        help='set image shape')
    parser.add_argument('--thresh', dest='thresh', type=float, default=0.5,
                        help='object visualize score threshold, default 0.5')
    parser.add_argument('--num-class', dest='num_class', type=int, default=14,
                        help='number of classes')
    parser.add_argument('--class-names', dest='class_names', type=str,
                        default='Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule, Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis, Pleural_Thickening, Hernia',
                        help='string of comma separated names, or text filename')
    # The script entry point only dispatches -1 (Cam) and 1 (Cam_resp); any
    # other value used to fall through and silently do nothing, so argparse
    # now rejects it up front with a usage error.
    parser.add_argument('--identifier', dest='identifier', type=int, default=-1,
                        choices=(-1, 1),
                        help='Use combined model if it equals -1.Use all models if it equals 1')
    return parser.parse_args(argv)
def parse_class_names(class_names):
    """Parse class names from a comma-separated string or a text file.

    Args:
        class_names: Either the path of an existing text file containing one
            class name per line, or a string of comma-separated class names.

    Returns:
        List of whitespace-stripped class names, in input order.

    Raises:
        RuntimeError: if ``class_names`` is empty/None, if a file yields no
            names, or if the comma-separated string contains an empty entry
            (e.g. ``"a,,b"`` or a trailing comma).
    """
    if not class_names:
        raise RuntimeError("No valid class_name provided...")

    if os.path.isfile(class_names):
        # One name per line; blank lines (e.g. extra newlines at end of file)
        # carry no class name and are skipped rather than aborting the run.
        with open(class_names, 'r', encoding='utf-8') as f:
            names = [line.strip() for line in f if line.strip()]
    else:
        names = [c.strip() for c in class_names.split(',')]

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not names:
        raise RuntimeError("No valid class_name provided...")
    if not all(names):
        raise RuntimeError("Empty class name in %r" % (class_names,))
    return names
if __name__ == '__main__':
    args = parse_args()
    ctx = mx.cpu() if args.cpu else mx.gpu(args.gpu_id)

    # Parse the image list. Empty entries (e.g. from "a.png," or an empty
    # --images value) are dropped. str.split always returns at least one
    # element, so a bare length check could never catch an empty value.
    image_list = [path.strip() for path in args.images.split(',') if path.strip()]
    if not image_list:
        sys.exit("No valid image specified to detect")

    network = args.network
    class_names = parse_class_names(args.class_names)

    # Select the CAM runner from --identifier: -1 -> combined model (Cam),
    # 1 -> all models (Cam_resp), per the option's help text. Unknown values
    # previously fell through and silently did nothing.
    runners = {-1: Cam, 1: Cam_resp}
    runner = runners.get(args.identifier)
    if runner is None:
        sys.exit("Unsupported --identifier %d; expected -1 or 1" % args.identifier)

    # run
    runner(network, image_list, args.model_path, ctx, args.data_shape,
           class_names, args.thresh, num_class=args.num_class)