Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add mobilenet ssd pascal model #40

Merged
merged 2 commits into from
Dec 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions models/mobilenet_ssd_pascal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Mobilenet SSD

We offer the mobilenet(1.0) ssd model trained on PASCAL VOC0712 dataset. This model can be deployed on embedded system
and you can modify the network to adapt to your own application.

## run the demo
1. Install PaddlePaddle(see: [PaddlePaddle installation instructions](http://paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html))

2. Download the [parameters](https://pan.baidu.com/s/1o7S8yWq) trained on PASCAL VOC0712.

3. `python infer.py`


## train on your own dataset
You can modify the network to adapt to your own application. PaddlePaddle provides a detailed document to show how to train your model with SSD, refer the document [here](https://github.com/PaddlePaddle/models/tree/develop/ssd).

Empty file.
21 changes: 21 additions & 0 deletions models/mobilenet_ssd_pascal/config/label_list
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
background
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
94 changes: 94 additions & 0 deletions models/mobilenet_ssd_pascal/config/pascal_voc_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from easydict import EasyDict as edict
import numpy as np

__C = edict()
cfg = __C

__C.TRAIN = edict()

__C.IMG_WIDTH = 300
__C.IMG_HEIGHT = 300
__C.IMG_CHANNEL = 3
__C.CLASS_NUM = 21
__C.BACKGROUND_ID = 0

# training settings
__C.TRAIN.MOMENTUM = 0.9
__C.TRAIN.BATCH_SIZE = 32
__C.TRAIN.LEARNING_RATE = 0.0005 / 4
#__C.TRAIN.LEARNING_RATE = 0.001 / __C.TRAIN.BATCH_SIZE
__C.TRAIN.NUM_PASS = 50000
__C.TRAIN.L2REGULARIZATION = 0.0005 * 4
#__C.TRAIN.L2REGULARIZATION = 0.0005 * __C.TRAIN.BATCH_SIZE
__C.TRAIN.LEARNING_RATE_DECAY_A = 0.1
__C.TRAIN.LEARNING_RATE_DECAY_B = 16551 * 80
__C.TRAIN.LEARNING_RATE_SCHEDULE = 'discexp'
#__C.TRAIN.LEARNING_RATE_SCHEDULE = 'constant'

__C.NET = edict()

# configuration for multibox_loss_layer
__C.NET.MBLOSS = edict()
__C.NET.MBLOSS.OVERLAP_THRESHOLD = 0.5
__C.NET.MBLOSS.NEG_POS_RATIO = 3.0
__C.NET.MBLOSS.NEG_OVERLAP = 0.5

# configuration for detection_map
__C.NET.DETMAP = edict()
__C.NET.DETMAP.OVERLAP_THRESHOLD = 0.5
__C.NET.DETMAP.EVAL_DIFFICULT = False
__C.NET.DETMAP.AP_TYPE = "11point"

# configuration for detection_output_layer
__C.NET.DETOUT = edict()
__C.NET.DETOUT.CONFIDENCE_THRESHOLD = 0.01
__C.NET.DETOUT.NMS_THRESHOLD = 0.45
__C.NET.DETOUT.NMS_TOP_K = 400
__C.NET.DETOUT.KEEP_TOP_K = 200

################################################
__C.NET.CONV11 = edict()
__C.NET.CONV11.PB = edict()
__C.NET.CONV11.PB.MIN_SIZE = [60]
__C.NET.CONV11.PB.ASPECT_RATIO = [2.]
#__C.NET.CONV11.PB.ASPECT_RATIO = [2.]
__C.NET.CONV11.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]


__C.NET.CONV13 = edict()
__C.NET.CONV13.PB = edict()
__C.NET.CONV13.PB.MIN_SIZE = [105]
__C.NET.CONV13.PB.MAX_SIZE = [150]
__C.NET.CONV13.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV13.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]


__C.NET.CONV14_2 = edict()
__C.NET.CONV14_2.PB = edict()
__C.NET.CONV14_2.PB.MIN_SIZE = [150]
__C.NET.CONV14_2.PB.MAX_SIZE = [195]
__C.NET.CONV14_2.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV14_2.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]


__C.NET.CONV15_2 = edict()
__C.NET.CONV15_2.PB = edict()
__C.NET.CONV15_2.PB.MIN_SIZE = [195]
__C.NET.CONV15_2.PB.MAX_SIZE = [240]
__C.NET.CONV15_2.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV15_2.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]

__C.NET.CONV16_2 = edict()
__C.NET.CONV16_2.PB = edict()
__C.NET.CONV16_2.PB.MIN_SIZE = [240]
__C.NET.CONV16_2.PB.MAX_SIZE = [285]
__C.NET.CONV16_2.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV16_2.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]

__C.NET.CONV17_2 = edict()
__C.NET.CONV17_2.PB = edict()
__C.NET.CONV17_2.PB.MIN_SIZE = [285]
__C.NET.CONV17_2.PB.MAX_SIZE = [300]
__C.NET.CONV17_2.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV17_2.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]

Binary file added models/mobilenet_ssd_pascal/images/example.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
88 changes: 88 additions & 0 deletions models/mobilenet_ssd_pascal/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
import gzip
import copy
import cv2, os

import paddle.v2 as paddle
from mobilenet_ssd_pascal import net_conf
from config.pascal_voc_conf import cfg

label_lists = open('./config/label_list').readlines()

def _infer(inferer, infer_data, threshold):
ret = []
infer_res = inferer.infer(input=infer_data)
keep_inds = np.where(infer_res[:, 2] >= threshold)[0]
for idx in keep_inds:
ret.append([
infer_res[idx][0], infer_res[idx][1] - 1, infer_res[idx][2],
infer_res[idx][3], infer_res[idx][4], infer_res[idx][5],
infer_res[idx][6]
])
return ret

def draw_result(frame, ret_res, h, w):
print ret_res
for det_res in ret_res:
img_idx = int(det_res[0])
label = int(det_res[1])
conf_score = det_res[2]
xmin = int(round(det_res[3] * w))
ymin = int(round(det_res[4] * h))
xmax = int(round(det_res[5] * w))
ymax = int(round(det_res[6] * h))
cv2.rectangle(frame, (xmin, ymin), (xmax, ymax),
(0, (1 - xmin) * 255, xmin * 255), 2)
font=cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(frame, label_lists[label + 1].strip(), (xmin + 10, ymin + 10), font, 1.0, (255, 0, 0), 2)


def pre_process(img):
img = cv2.resize(img, (cfg.IMG_HEIGHT, cfg.IMG_WIDTH), interpolation=cv2.INTER_AREA)
# image should be RGB format
img = img[:, :, ::-1]
# image shoud be in CHW format
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
img = img.astype('float32')

img_mean = np.array([104, 117, 124])[:, np.newaxis, np.newaxis].astype(
'float32')
img -= img_mean
img = img.flatten()
return img

def infer(model_path,threshold):

net = net_conf(mode='infer')

assert os.path.isfile(model_path), 'Invalid model.'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))

#build the inference network
inferer = paddle.inference.Inference(
output_layer=net, parameters=parameters)

test_data = []

frame = cv2.imread('./images/example.jpg')

h, w, _ = frame.shape
img = copy.deepcopy(frame)

# preprocess the image
img = pre_process(img)
test_data.append([img])

#the forward process
ret_res = _infer(inferer, test_data, threshold)

draw_result(frame, ret_res, h, w)
cv2.imwrite('./images/result.jpg', frame)

if __name__ == "__main__":
# init paddle environment
paddle.init(use_gpu=False, trainer_count=1, gpu_id= 3)

infer(model_path='./mobilenet_ssd_pascal.tar.gz',
threshold=0.3)
Loading