# model.py
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: original (possibly fractional) channel count.
    :param divisor: the returned value is a multiple of this number.
    :param min_value: lower bound for the result; defaults to `divisor`.
    :return: `v` rounded to a multiple of `divisor`, never more than 10% below `v`.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
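

# Illustrative note (not part of the original file): how the rounding behaves
# for a small width multiplier, worked out by hand from the function above.
# 32 * 0.35 = 11.2 rounds down to 8, but 8 < 0.9 * 11.2, so the result is
# pushed back up by one divisor step:
#
#   _make_divisible(32 * 0.35, 8)  -> 16
#   _make_divisible(24 * 1.0, 8)   -> 24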


class LinearBottleneck(nn.Module):
    def __init__(self, inplanes, outplanes, stride=1, t=6, activation=nn.ReLU6):
        super(LinearBottleneck, self).__init__()
        # 1x1 expansion, 3x3 depthwise (groups == channels), 1x1 linear projection
        self.conv1 = nn.Conv2d(inplanes, inplanes * t, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(inplanes * t)
        self.conv2 = nn.Conv2d(inplanes * t, inplanes * t, kernel_size=3, stride=stride, padding=1, bias=False,
                               groups=inplanes * t)
        self.bn2 = nn.BatchNorm2d(inplanes * t)
        self.conv3 = nn.Conv2d(inplanes * t, outplanes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplanes)
        self.activation = activation(inplace=True)
        self.stride = stride
        self.t = t
        self.inplanes = inplanes
        self.outplanes = outplanes

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.activation(out)

        # no activation after the projection: the bottleneck is "linear"
        out = self.conv3(out)
        out = self.bn3(out)

        # residual connection only when spatial size and channel count are preserved
        if self.stride == 1 and self.inplanes == self.outplanes:
            out += residual

        return out
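

# Illustrative note (not part of the original file): shape behaviour of a block.
# LinearBottleneck(32, 16, stride=1, t=1) maps (N, 32, H, W) -> (N, 16, H, W);
# with stride=2 the spatial size is halved, e.g.
#
#   block = LinearBottleneck(inplanes=24, outplanes=32, stride=2, t=6)
#   block(torch.randn(1, 24, 56, 56)).shape  # torch.Size([1, 32, 28, 28])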


class MobileNet2(nn.Module):
    """MobileNet2 implementation.
    """

    def __init__(self, scale=1.0, input_size=224, t=6, in_channels=3, num_classes=1000, activation=nn.ReLU6):
        """
        MobileNet2 constructor.
        :param in_channels: (int, optional): number of channels in the input tensor.
            Default is 3 for RGB image inputs.
        :param input_size: spatial size of the (square) input image. Default is 224.
        :param num_classes: number of classes to predict. Default
            is 1000 for ImageNet.
        :param scale: width multiplier applied to the channel count of every layer.
        :param t: expansion factor of the bottleneck blocks.
        :param activation: activation module class to use. Default is nn.ReLU6.
        """
        super(MobileNet2, self).__init__()

        self.scale = scale
        self.t = t
        self.activation_type = activation
        self.activation = activation(inplace=True)
        self.num_classes = num_classes

        self.num_of_channels = [32, 16, 24, 32, 64, 96, 160, 320]
        # assert (input_size % 32 == 0)

        # per-stage channel counts (c), repeat counts (n) and strides (s)
        self.c = [_make_divisible(ch * self.scale, 8) for ch in self.num_of_channels]
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]
        self.conv1 = nn.Conv2d(in_channels, self.c[0], kernel_size=3, bias=False, stride=self.s[0], padding=1)
        self.bn1 = nn.BatchNorm2d(self.c[0])
        self.bottlenecks = self._make_bottlenecks()

        # Last convolution has 1280 output channels for scale <= 1
        self.last_conv_out_ch = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8)
        self.conv_last = nn.Conv2d(self.c[-1], self.last_conv_out_ch, kernel_size=1, bias=False)
        self.bn_last = nn.BatchNorm2d(self.last_conv_out_ch)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.2, inplace=True)  # confirmed by paper authors
        self.fc = nn.Linear(self.last_conv_out_ch, self.num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_stage(self, inplanes, outplanes, n, stride, t, stage):
        modules = OrderedDict()
        stage_name = "LinearBottleneck{}".format(stage)

        # First module is the only one utilizing stride
        first_module = LinearBottleneck(inplanes=inplanes, outplanes=outplanes, stride=stride, t=t,
                                        activation=self.activation_type)
        modules[stage_name + "_0"] = first_module

        # add more LinearBottleneck depending on number of repeats
        for i in range(n - 1):
            name = stage_name + "_{}".format(i + 1)
            # repeats keep stride=1 and reuse the stage's expansion factor t
            module = LinearBottleneck(inplanes=outplanes, outplanes=outplanes, stride=1, t=t,
                                      activation=self.activation_type)
            modules[name] = module

        return nn.Sequential(modules)
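
    # Illustrative note (not part of the original file): with the default tables
    # above, stage index 2 has c[2]=24 -> c[3]=32, n[3]=3 and s[3]=2, so
    # _make_stage builds three LinearBottlenecks: the first with stride 2
    # (24 -> 32), followed by two stride-1 blocks (32 -> 32) that use the
    # residual path.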

    def _make_bottlenecks(self):
        modules = OrderedDict()
        stage_name = "Bottlenecks"

        # First module is the only one with t=1
        bottleneck1 = self._make_stage(inplanes=self.c[0], outplanes=self.c[1], n=self.n[1], stride=self.s[1], t=1,
                                       stage=0)
        modules[stage_name + "_0"] = bottleneck1

        # add more LinearBottleneck depending on number of repeats
        for i in range(1, len(self.c) - 1):
            name = stage_name + "_{}".format(i)
            module = self._make_stage(inplanes=self.c[i], outplanes=self.c[i + 1], n=self.n[i + 1],
                                      stride=self.s[i + 1],
                                      t=self.t, stage=i)
            modules[name] = module

        return nn.Sequential(modules)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.bottlenecks(x)
        x = self.conv_last(x)
        x = self.bn_last(x)
        x = self.activation(x)

        # average pooling layer
        x = self.avgpool(x)
        x = self.dropout(x)

        # flatten for input to fully-connected layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)  # TODO: not needed(?)


if __name__ == "__main__":
    """Testing
    """
    model1 = MobileNet2()
    print(model1)
    model2 = MobileNet2(scale=0.35)
    print(model2)
    model3 = MobileNet2(in_channels=2, num_classes=10)
    print(model3)
    x = torch.randn(1, 2, 224, 224)
    print(model3(x))
    model4_size = 32 * 10
    model4 = MobileNet2(input_size=model4_size, num_classes=10)
    print(model4)
    x2 = torch.randn(1, 3, model4_size, model4_size)
    print(model4(x2))
    model5 = MobileNet2(input_size=196, num_classes=10)
    x3 = torch.randn(1, 3, 196, 196)
    print(model5(x3))  # fail
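
    # Illustrative addition (not part of the original file): inference-style use.
    # BatchNorm statistics are frozen with .eval(), and forward() returns
    # log-probabilities because of the final log_softmax.
    model6 = MobileNet2(num_classes=10)
    model6.eval()
    with torch.no_grad():
        logp = model6(torch.randn(2, 3, 224, 224))
    print(logp.shape)  # torch.Size([2, 10])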