diff --git a/models/mobilenet_ssd_pascal/README.md b/models/mobilenet_ssd_pascal/README.md new file mode 100644 index 0000000..80972b7 --- /dev/null +++ b/models/mobilenet_ssd_pascal/README.md @@ -0,0 +1,16 @@ +# Mobilenet SSD + +We offer the MobileNet (1.0) SSD model trained on the PASCAL VOC0712 dataset. This model can be deployed on embedded systems, +and you can modify the network to adapt to your own application. + +## Run the demo +1. Install PaddlePaddle (see: [PaddlePaddle installation instructions](http://paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html)) + +2. Download the [parameters](https://pan.baidu.com/s/1o7S8yWq) trained on PASCAL VOC0712. + +3. `python infer.py` + + +## Train on your own dataset +You can modify the network to adapt to your own application. PaddlePaddle provides a detailed document showing how to train your model with SSD; refer to the document [here](https://github.com/PaddlePaddle/models/tree/develop/ssd). + diff --git a/models/mobilenet_ssd_pascal/config/__init__.py b/models/mobilenet_ssd_pascal/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/mobilenet_ssd_pascal/config/label_list b/models/mobilenet_ssd_pascal/config/label_list new file mode 100644 index 0000000..87df23c --- /dev/null +++ b/models/mobilenet_ssd_pascal/config/label_list @@ -0,0 +1,21 @@ +background +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +pottedplant +sheep +sofa +train +tvmonitor diff --git a/models/mobilenet_ssd_pascal/config/pascal_voc_conf.py b/models/mobilenet_ssd_pascal/config/pascal_voc_conf.py new file mode 100644 index 0000000..0bf5af4 --- /dev/null +++ b/models/mobilenet_ssd_pascal/config/pascal_voc_conf.py @@ -0,0 +1,94 @@ +from easydict import EasyDict as edict +import numpy as np + +__C = edict() +cfg = __C + +__C.TRAIN = edict() + +__C.IMG_WIDTH = 300 +__C.IMG_HEIGHT = 300 +__C.IMG_CHANNEL = 3 +__C.CLASS_NUM = 21 
+__C.BACKGROUND_ID = 0 + +# training settings +__C.TRAIN.MOMENTUM = 0.9 +__C.TRAIN.BATCH_SIZE = 32 +__C.TRAIN.LEARNING_RATE = 0.0005 / 4 +#__C.TRAIN.LEARNING_RATE = 0.001 / __C.TRAIN.BATCH_SIZE +__C.TRAIN.NUM_PASS = 50000 +__C.TRAIN.L2REGULARIZATION = 0.0005 * 4 +#__C.TRAIN.L2REGULARIZATION = 0.0005 * __C.TRAIN.BATCH_SIZE +__C.TRAIN.LEARNING_RATE_DECAY_A = 0.1 +__C.TRAIN.LEARNING_RATE_DECAY_B = 16551 * 80 +__C.TRAIN.LEARNING_RATE_SCHEDULE = 'discexp' +#__C.TRAIN.LEARNING_RATE_SCHEDULE = 'constant' + +__C.NET = edict() + +# configuration for multibox_loss_layer +__C.NET.MBLOSS = edict() +__C.NET.MBLOSS.OVERLAP_THRESHOLD = 0.5 +__C.NET.MBLOSS.NEG_POS_RATIO = 3.0 +__C.NET.MBLOSS.NEG_OVERLAP = 0.5 + +# configuration for detection_map +__C.NET.DETMAP = edict() +__C.NET.DETMAP.OVERLAP_THRESHOLD = 0.5 +__C.NET.DETMAP.EVAL_DIFFICULT = False +__C.NET.DETMAP.AP_TYPE = "11point" + +# configuration for detection_output_layer +__C.NET.DETOUT = edict() +__C.NET.DETOUT.CONFIDENCE_THRESHOLD = 0.01 +__C.NET.DETOUT.NMS_THRESHOLD = 0.45 +__C.NET.DETOUT.NMS_TOP_K = 400 +__C.NET.DETOUT.KEEP_TOP_K = 200 + +################################################ +__C.NET.CONV11 = edict() +__C.NET.CONV11.PB = edict() +__C.NET.CONV11.PB.MIN_SIZE = [60] +__C.NET.CONV11.PB.ASPECT_RATIO = [2.] +#__C.NET.CONV11.PB.ASPECT_RATIO = [2.] +__C.NET.CONV11.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +__C.NET.CONV13 = edict() +__C.NET.CONV13.PB = edict() +__C.NET.CONV13.PB.MIN_SIZE = [105] +__C.NET.CONV13.PB.MAX_SIZE = [150] +__C.NET.CONV13.PB.ASPECT_RATIO = [2., 3.] +__C.NET.CONV13.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +__C.NET.CONV14_2 = edict() +__C.NET.CONV14_2.PB = edict() +__C.NET.CONV14_2.PB.MIN_SIZE = [150] +__C.NET.CONV14_2.PB.MAX_SIZE = [195] +__C.NET.CONV14_2.PB.ASPECT_RATIO = [2., 3.] 
# Label names are read lazily so that importing this module performs no file
# I/O; the list is cached here after the first call to _get_labels().
label_lists = None


def _get_labels():
    """Return the lines of ./config/label_list, reading the file on first use.

    The file is opened with a context manager so the handle is always closed.
    """
    global label_lists
    if label_lists is None:
        with open('./config/label_list') as label_file:
            label_lists = label_file.readlines()
    return label_lists


def _infer(inferer, infer_data, threshold):
    """Run the network and keep the detections scoring at least `threshold`.

    Args:
        inferer: object exposing ``infer(input=...)``; its result is indexed
            as a 2-D array with at least seven columns.
        infer_data: batch handed unchanged to ``inferer.infer``.
        threshold: minimum value of column 2 for a row to be kept.

    Returns:
        A list with one 7-item list per kept row, in the original row order.
        Column 1 is decremented by one; every other column is copied as-is
        (draw_result reads the columns as image index, label, score, xmin,
        ymin, xmax, ymax).
    """
    infer_res = inferer.infer(input=infer_data)
    kept_rows = infer_res[infer_res[:, 2] >= threshold]
    return [[row[0], row[1] - 1, row[2], row[3], row[4], row[5], row[6]]
            for row in kept_rows]


def draw_result(frame, ret_res, h, w):
    """Draw every detection of `ret_res` onto `frame` (modified in place).

    Args:
        frame: image passed to cv2.rectangle / cv2.putText.
        ret_res: rows as produced by _infer; columns 3..6 are multiplied by
            `w` / `h` to obtain pixel coordinates.
        h: image height used to scale the y coordinates.
        w: image width used to scale the x coordinates.
    """
    # Function-call form prints the same text on Python 2 and also parses on
    # Python 3.
    print(ret_res)
    labels = _get_labels()
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det_res in ret_res:
        label = int(det_res[1])
        xmin = int(round(det_res[3] * w))
        ymin = int(round(det_res[4] * h))
        xmax = int(round(det_res[5] * w))
        ymax = int(round(det_res[6] * h))
        # The colour fades with the horizontal position of the box. The
        # original scaled the *pixel* xmin, which pushes the channel values
        # outside 0..255 for any xmin > 1; the normalised coordinate
        # (clamped to [0, 1]) keeps them in range.
        rel_x = min(max(float(det_res[3]), 0.0), 1.0)
        box_color = (0, (1 - rel_x) * 255, rel_x * 255)
        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), box_color, 2)
        # _infer subtracted 1 from the label id, hence the "+ 1" here.
        cv2.putText(frame, labels[label + 1].strip(), (xmin + 10, ymin + 10),
                    font, 1.0, (255, 0, 0), 2)


def pre_process(img):
    """Convert an image as loaded by cv2 into the flat float32 network input.

    Steps: resize to cfg.IMG_WIDTH x cfg.IMG_HEIGHT, reverse the channel
    axis, reorder HWC -> CHW, subtract the per-channel mean and flatten.

    Args:
        img: H x W x 3 array (as returned by cv2.imread); it is not modified.

    Returns:
        1-D float32 numpy array.
    """
    # cv2.resize expects dsize as (width, height); the original passed
    # (height, width), which only worked because both are 300.
    img = cv2.resize(img, (cfg.IMG_WIDTH, cfg.IMG_HEIGHT),
                     interpolation=cv2.INTER_AREA)
    # image should be RGB format
    img = img[:, :, ::-1]
    # image should be in CHW format
    img = img.transpose((2, 0, 1)).astype('float32')
    img_mean = np.array([104, 117, 124],
                        dtype='float32')[:, np.newaxis, np.newaxis]
    img -= img_mean
    return img.flatten()


def infer(model_path, threshold, image_path='./images/example.jpg',
          output_path='./images/result.jpg'):
    """Run detection on one image and write the annotated result to disk.

    Args:
        model_path: gzip-compressed tar file with the trained parameters.
        threshold: minimum confidence for a detection to be drawn.
        image_path: image to run on (defaults to the bundled example).
        output_path: where the annotated image is written.

    Raises:
        AssertionError: if `model_path` is not an existing file.
        IOError: if `image_path` cannot be read by cv2.imread.
    """
    assert os.path.isfile(model_path), 'Invalid model.'

    net = net_conf(mode='infer')
    # Close the archive as soon as the parameters are loaded.
    with gzip.open(model_path) as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)

    # build the inference network
    inferer = paddle.inference.Inference(
        output_layer=net, parameters=parameters)

    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Cannot read image: %s' % image_path)
    h, w, _ = frame.shape

    # pre_process works on a resized copy, so `frame` stays intact for drawing.
    test_data = [[pre_process(frame)]]

    # the forward process
    ret_res = _infer(inferer, test_data, threshold)

    draw_result(frame, ret_res, h, w)
    cv2.imwrite(output_path, frame)


if __name__ == "__main__":
    # init paddle environment
    paddle.init(use_gpu=False, trainer_count=1, gpu_id=3)

    infer(model_path='./mobilenet_ssd_pascal.tar.gz', threshold=0.3)
def net_conf(mode, scale=1.0):
    """Build the MobileNet-SSD detection network.

    Layers are created in a fixed order: input, conv0, 13 depthwise-separable
    blocks, 4 extra 1x1/3x3 pairs, then for every detection source its
    priorbox, cross-channel norm, location conv and confidence conv.

    Args:
        mode: 'train', 'eval' or 'infer'.
        scale: multiplier applied to the filter counts of the backbone and
            of the extra layers.

    Returns:
        ``(loss, detection_out)`` for 'train' and 'eval' (the detection mAP
        evaluator is registered as well); ``detection_out`` for 'infer'.

    Raises:
        ValueError: if `mode` is not one of the three supported values.
            Previously the whole graph was built and None was returned.
    """
    if mode not in ('train', 'eval', 'infer'):
        raise ValueError(
            "mode must be 'train', 'eval' or 'infer', got %r" % (mode,))

    default_l2regularization = cfg.TRAIN.L2REGULARIZATION
    default_bias_attr = paddle.attr.ParamAttr(l2_rate=0.0, learning_rate=2.0)

    def get_param_attr(local_lr, regularization):
        """Weight attribute; a learning rate of 0 freezes the parameter."""
        return paddle.attr.ParamAttr(
            learning_rate=local_lr,
            l2_rate=regularization,
            is_static=(local_lr == 0.0))

    def mbox_block(input, num_channels, filter_size, loc_filters,
                   conf_filters):
        """Create the location conv, then the confidence conv, on `input`."""
        heads = []
        for num_filters in (loc_filters, conf_filters):
            heads.append(paddle.layer.img_conv(
                input=input,
                filter_size=filter_size,
                num_channels=num_channels,
                num_filters=num_filters,
                stride=1,
                padding=0,
                layer_type='exconv',
                bias_attr=default_bias_attr,
                param_attr=get_param_attr(1, default_l2regularization),
                act=paddle.activation.Identity()))
        return heads[0], heads[1]

    def conv_bn_layer(input, filter_size, num_filters, stride, padding,
                      channels=None, num_groups=1,
                      active_type=paddle.activation.Relu()):
        """Convolution with bias and activation.

        The name is historical: the batch_norm layer was deleted from this
        network, so the activation and the bias live on the conv itself.
        """
        return paddle.layer.img_conv(
            input=input,
            filter_size=filter_size,
            num_channels=channels,
            num_filters=num_filters,
            stride=stride,
            padding=padding,
            groups=num_groups,
            layer_type='exconv',
            # !!! the act in the network with batch norm
            # is paddle.activation.Linear()
            act=active_type,
            # !!! the bias_attr in origin network is False
            bias_attr=True)

    def depthwise_separable(input, num_filters1, num_filters2, num_groups,
                            stride):
        """3x3 grouped conv (carries the stride) followed by a 1x1 conv."""
        depthwise = conv_bn_layer(
            input=input,
            filter_size=3,
            num_filters=num_filters1,
            stride=stride,
            padding=1,
            num_groups=num_groups)
        return conv_bn_layer(
            input=depthwise,
            filter_size=1,
            num_filters=num_filters2,
            stride=1,
            padding=0)

    img = paddle.layer.data(
        name='image',
        type=paddle.data_type.dense_vector(
            cfg.IMG_CHANNEL * cfg.IMG_HEIGHT * cfg.IMG_WIDTH),
        height=cfg.IMG_HEIGHT,
        width=cfg.IMG_WIDTH)

    # conv0: plain 3x3 convolution on the 3-channel input.
    conv0 = conv_bn_layer(
        img,
        filter_size=3,
        channels=3,
        num_filters=int(32 * scale),
        stride=2,
        padding=1)

    # conv1 .. conv13 as (input filters, output filters, stride). The number
    # of groups of the 3x3 conv always equals its (scaled) input filters.
    backbone_spec = (
        [(32, 64, 1), (64, 128, 2), (128, 128, 1), (128, 256, 2),
         (256, 256, 1), (256, 512, 2)]
        + [(512, 512, 1)] * 5                    # conv7 .. conv11
        + [(512, 1024, 2), (1024, 1024, 1)])     # conv12, conv13
    backbone = [conv0]
    for in_filters, out_filters, stride in backbone_spec:
        backbone.append(depthwise_separable(
            backbone[-1],
            num_filters1=int(in_filters * scale),
            num_filters2=int(out_filters * scale),
            num_groups=int(in_filters * scale),
            stride=stride))
    # backbone[k] is conv<k>.
    conv11, conv13 = backbone[11], backbone[13]

    # conv14 .. conv17: a 1x1 conv followed by a stride-2 3x3 conv; only the
    # second conv of every pair (convNN_2) is a detection source.
    extras = []
    feature = conv13
    for squeeze_filters, expand_filters in ((256, 512), (128, 256),
                                            (128, 256), (64, 128)):
        feature = conv_bn_layer(
            input=feature,
            filter_size=1,
            num_filters=int(squeeze_filters * scale),
            stride=1,
            padding=0)
        feature = conv_bn_layer(
            input=feature,
            filter_size=3,
            num_filters=int(expand_filters * scale),
            stride=2,
            padding=1)
        extras.append(feature)
    conv14_2, conv15_2, conv16_2, conv17_2 = extras

    # One row per detection source: (feature layer, key below cfg.NET,
    # name of the cross_channel_norm layer, unscaled channel count,
    # boxes predicted per location).
    # The original hard-coded 12/63 location/confidence filters for conv11
    # and 24/126 for the others, i.e. 3 and 6 boxes times (4 coordinates,
    # 21 classes). Deriving them from cfg.CLASS_NUM keeps the heads in sync
    # with detection_output and multibox_loss when the class count changes.
    # NOTE(review): the norm layer of conv14_2 is named "conv14_2" (without
    # the "_norm" suffix) in the original; kept as-is because saved
    # parameter names presumably depend on it -- confirm before renaming.
    sources = [
        (conv11, 'CONV11', 'conv11_norm', 512, 3),
        (conv13, 'CONV13', 'conv13_norm', 1024, 6),
        (conv14_2, 'CONV14_2', 'conv14_2', 512, 6),
        (conv15_2, 'CONV15_2', 'conv15_2_norm', 256, 6),
        (conv16_2, 'CONV16_2', 'conv16_2_norm', 256, 6),
        (conv17_2, 'CONV17_2', 'conv17_2_norm', 128, 6),
    ]

    priorboxes = []
    loc_loss_input = []
    conf_loss_input = []
    for feature, cfg_key, norm_name, channels, num_priors in sources:
        pb_cfg = getattr(cfg.NET, cfg_key).PB
        priorbox_args = dict(
            input=feature,
            image=img,
            min_size=pb_cfg.MIN_SIZE,
            aspect_ratio=pb_cfg.ASPECT_RATIO,
            variance=pb_cfg.VARIANCE)
        # CONV11 defines no MAX_SIZE; only pass max_size when configured.
        if hasattr(pb_cfg, 'MAX_SIZE'):
            priorbox_args['max_size'] = pb_cfg.MAX_SIZE
        priorboxes.append(paddle.layer.priorbox(**priorbox_args))

        normed = paddle.layer.cross_channel_norm(
            name=norm_name,
            input=feature,
            param_attr=paddle.attr.ParamAttr(
                initial_mean=20,
                initial_std=0,
                is_static=False,
                learning_rate=1))

        mbox_loc, mbox_conf = mbox_block(
            normed, int(channels * scale), 1,
            num_priors * 4, num_priors * cfg.CLASS_NUM)
        loc_loss_input.append(mbox_loc)
        conf_loss_input.append(mbox_conf)

    mbox_priorbox = paddle.layer.concat(
        name="mbox_priorbox", input=priorboxes)

    detection_out = paddle.layer.detection_output(
        input_loc=loc_loss_input,
        input_conf=conf_loss_input,
        priorbox=mbox_priorbox,
        confidence_threshold=cfg.NET.DETOUT.CONFIDENCE_THRESHOLD,
        nms_threshold=cfg.NET.DETOUT.NMS_THRESHOLD,
        num_classes=cfg.CLASS_NUM,
        nms_top_k=cfg.NET.DETOUT.NMS_TOP_K,
        keep_top_k=cfg.NET.DETOUT.KEEP_TOP_K,
        background_id=cfg.BACKGROUND_ID,
        name="detection_output")

    if mode == 'infer':
        return detection_out

    # 'train' and 'eval': add the ground-truth input, the loss and the mAP
    # evaluator.
    bbox = paddle.layer.data(
        name='bbox', type=paddle.data_type.dense_vector_sequence(6))
    loss = paddle.layer.multibox_loss(
        input_loc=loc_loss_input,
        input_conf=conf_loss_input,
        priorbox=mbox_priorbox,
        label=bbox,
        num_classes=cfg.CLASS_NUM,
        overlap_threshold=cfg.NET.MBLOSS.OVERLAP_THRESHOLD,
        neg_pos_ratio=cfg.NET.MBLOSS.NEG_POS_RATIO,
        neg_overlap=cfg.NET.MBLOSS.NEG_OVERLAP,
        background_id=cfg.BACKGROUND_ID,
        name="multibox_loss")
    paddle.evaluator.detection_map(
        input=detection_out,
        label=bbox,
        overlap_threshold=cfg.NET.DETMAP.OVERLAP_THRESHOLD,
        background_id=cfg.BACKGROUND_ID,
        evaluate_difficult=cfg.NET.DETMAP.EVAL_DIFFICULT,
        ap_type=cfg.NET.DETMAP.AP_TYPE,
        name="detection_evaluator")
    return loss, detection_out


if __name__ == '__main__':
    out = net_conf('infer', scale=1.0)