-
Notifications
You must be signed in to change notification settings - Fork 0
/
squeezenext_architecture.py
163 lines (137 loc) · 6.98 KB
/
squeezenext_architecture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import absolute_import
import tensorflow as tf
slim = tf.contrib.slim
import tensorflow_extentions as tfe
def squeezenext_unit(inputs, filters, stride, height_first_order, groups, seperate_relus):
    """
    Squeezenext unit according to:
    https://arxiv.org/pdf/1803.10615.pdf
    :param inputs:
        Input tensor
    :param filters:
        Number of filters at output of this unit
    :param stride:
        Input stride
    :param height_first_order:
        Whether to first perform seperable convolution in the vertical direcation or horizontal direction
    :param groups:
        Number of groups for some of the convolutions (which ones are different from the paper but equal to:
        https://github.com/amirgholami/SqueezeNext/blob/master/1.0-G-SqNxt-23/train_val.prototxt)
    :param seperate_relus:
        If True, the shortcut and output convolutions keep their own relu activation before the
        residual addition; if False they are linear and only the post-addition relu is applied.
    :return:
        Output tensor, not(height_first_order)
    """
    input_channels = inputs.get_shape().as_list()[-1]
    shortcut = inputs
    out_activation = tf.nn.relu if seperate_relus else None
    # shorcut convolution only to be executed if input channels is different from output channels or
    # stride is greater than 1.
    if input_channels != filters or stride != 1:
        shortcut = slim.conv2d(shortcut, filters, [1, 1], stride=stride, activation_fn=out_activation)
    # input 1x1 reduction convolutions.
    # NOTE: use floor division so the channel count stays an int under Python 3
    # (true division would yield a float and break the conv layer).
    block = tfe.grouped_convolution(inputs, filters // 2, [1, 1], groups, stride=stride)
    block = slim.conv2d(block, block.get_shape().as_list()[-1] // 2, [1, 1])
    # seperable convolutions: a [3,1] then [1,3] pair (or the reverse), alternating
    # between units via height_first_order
    if height_first_order:
        input_channels_seperated = block.get_shape().as_list()[-1]
        block = tfe.grouped_convolution(block, input_channels_seperated * 2, [3, 1], groups)
        block = tfe.grouped_convolution(block, block.get_shape().as_list()[-1], [1, 3], groups)
    else:
        input_channels_seperated = block.get_shape().as_list()[-1]
        block = tfe.grouped_convolution(block, input_channels_seperated * 2, [1, 3], groups)
        block = tfe.grouped_convolution(block, block.get_shape().as_list()[-1], [3, 1], groups)
    # switch order next unit
    height_first_order = not height_first_order
    # output convolutions (doubles channels back up to `filters`)
    block = slim.conv2d(block, block.get_shape().as_list()[-1] * 2, [1, 1], activation_fn=out_activation)
    assert block.get_shape().as_list()[-1] == filters, "Block output channels not equal to number of specified filters"
    return tf.nn.relu(block + shortcut), height_first_order
def arg_scope(is_training,
              weight_decay=0.0001,
              updates_collections=None):
    """
    Setup slim arg scope according to paper and github project.

    :param is_training:
        Whether or not the network is training (passed to batch norm)
    :param weight_decay:
        Weight decay of the convolutional layers
    :param updates_collections:
        Collection for the batch-norm update ops; defaults to
        tf.GraphKeys.UPDATE_OPS when None.
    :return:
        Slim arg scope
    """
    # Resolve the default collection up front instead of inline in the dict.
    if updates_collections is None:
        updates_collections = tf.GraphKeys.UPDATE_OPS
    batch_norm_params = {
        'is_training': is_training,
        'center': True,
        'scale': True,
        'decay': 0.999,
        'epsilon': 1e-5,
        'fused': True,
        "updates_collections": updates_collections,
    }
    # Use xavier an l2 decay
    weights_init = tf.contrib.layers.xavier_initializer()
    regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    conv_like_ops = [slim.conv2d, tfe.grouped_convolution, slim.separable_conv2d]
    with slim.arg_scope(conv_like_ops,
                        weights_initializer=weights_init,
                        normalizer_fn=slim.batch_norm,
                        biases_initializer=None,
                        activation_fn=tf.nn.relu):
        with slim.arg_scope([slim.conv2d, tfe.grouped_convolution],
                            weights_regularizer=regularizer):
            with slim.arg_scope([slim.batch_norm], **batch_norm_params) as final_scope:
                return final_scope
class SqueezeNext(object):
    """Base class for building the SqueezeNext Model."""

    def __init__(self, num_classes, block_defs, input_def, groups, seperate_relus):
        # num_classes: size of the final fully connected output
        # block_defs: iterable of (filters, units, stride) tuples, one per block
        # input_def: (filters, kernel, stride) for the input convolution
        # groups / seperate_relus: forwarded to every squeezenext_unit
        self.num_classes = num_classes
        self.block_defs = block_defs
        self.input_def = input_def
        self.groups = groups
        self.seperate_relus = seperate_relus

    def __call__(self, inputs, height_first_order=True):
        """Add operations to classify a batch of input images.

        Args:
          inputs: A Tensor representing a batch of input images.
          height_first_order: Whether the first unit starts its separable
            pair in the vertical direction; the order alternates per unit.

        Returns:
          (logits, endpoints): a logits Tensor with shape
          [<batch_size>, self.num_classes] and a dict of named
          intermediate tensors.
        """
        with tf.variable_scope("squeezenext"):
            input_filters, input_kernel, input_stride = self.input_def
            endpoints = {}
            # input convolution and pooling
            net = slim.conv2d(inputs, input_filters, input_kernel, stride=input_stride,
                              scope="input_conv", padding="VALID")
            endpoints["input_conv"] = net
            net = slim.max_pool2d(net, [3, 3], stride=2)
            endpoints["max_pool"] = net
            # create block based network
            for block_idx, (filters, units, stride) in enumerate(self.block_defs):
                with tf.variable_scope("block_{}".format(block_idx)):
                    # create seperate units inside a block
                    for unit_idx in range(units):
                        with tf.variable_scope("unit_{}".format(unit_idx)):
                            # striding is performed only in the first unit of a block
                            unit_stride = stride if unit_idx == 0 else 1
                            net, height_first_order = squeezenext_unit(
                                net, filters, unit_stride, height_first_order,
                                self.groups, self.seperate_relus)
                            endpoints["block_{}/unit_{}".format(block_idx, unit_idx)] = net
            # output conv and pooling
            net = slim.conv2d(net, 128, [1, 1], scope="output_conv")
            endpoints["output_conv"] = net
            pooled = slim.avg_pool2d(net, net.get_shape().as_list()[1:3],
                                     scope="avg_pool_out", padding="VALID")
            net = tf.squeeze(pooled, axis=[1, 2])
            endpoints["avg_pool_out"] = net
            # Fully connected output without biases
            output = slim.fully_connected(net, self.num_classes, activation_fn=None,
                                          normalizer_fn=None, biases_initializer=None)
            endpoints["output"] = output
            return output, endpoints

    def model_arg_scope(self, is_training,
                        weight_decay=0.0001,
                        updates_collections=None):
        # Thin wrapper delegating to the module-level arg_scope.
        return arg_scope(is_training, weight_decay=weight_decay,
                         updates_collections=updates_collections)