-
Notifications
You must be signed in to change notification settings - Fork 0
/
rpn_vgg16.py
30 lines (27 loc) · 1.32 KB
/
rpn_vgg16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Model, Sequential
def get_rpn_model(hyper_params):
    """Build a region proposal network on top of a VGG16 backbone.

    inputs:
        hyper_params = dictionary, read for the keys "img_size" and "anchor_count"

    outputs:
        rpn_model = tf.keras.model taking the VGG16 input and producing
            [rpn_reg output, rpn_cls output], in that order
        feature_extractor = the "block5_conv3" layer of the VGG16 base model
    """
    side = hyper_params["img_size"]
    backbone = VGG16(include_top=False, input_shape=(side, side, 3))
    feature_extractor = backbone.get_layer("block5_conv3")

    # Shared 3x3 convolution whose result feeds both 1x1 heads below.
    shared_features = Conv2D(
        512,
        (3, 3),
        activation="relu",
        padding="same",
        name="rpn_conv",
    )(feature_extractor.output)

    anchor_count = hyper_params["anchor_count"]
    head_kernel = (1, 1)
    # Classification head: one sigmoid channel per anchor.
    rpn_cls_output = Conv2D(
        anchor_count, head_kernel, activation="sigmoid", name="rpn_cls"
    )(shared_features)
    # Regression head: anchor_count * 4 linear channels.
    rpn_reg_output = Conv2D(
        anchor_count * 4, head_kernel, activation="linear", name="rpn_reg"
    )(shared_features)

    rpn_model = Model(
        inputs=backbone.input,
        outputs=[rpn_reg_output, rpn_cls_output],
    )
    return rpn_model, feature_extractor
def init_model(model, img_size=512):
    """Initializing model with dummy data for load weights with optimizer state and also graph construction.

    inputs:
        model = tf.keras.model
        img_size = height and width of the square dummy image fed to the model.
            Defaults to 512, the value that used to be hard-coded here; pass the
            same "img_size" the model was built with when it differs from 512.
    """
    # One forward pass on a single random 3-channel image; the result is discarded.
    model(tf.random.uniform((1, img_size, img_size, 3)))