# pnasnet.py
'''PNASNet in PyTorch.
Paper: Progressive Neural Architecture Search
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
class SepConv(nn.Module):
    '''Grouped convolution followed by batch normalisation.

    The convolution is built with ``groups=in_planes``, a padding of
    ``(kernel_size - 1) // 2`` and no bias term (a batch norm follows it).

    Args:
        in_planes: number of input channels; also used as the group count.
        out_planes: number of output channels.
        kernel_size: integer side length of the square kernel.
        stride: stride of the convolution.
    '''

    def __init__(self, in_planes, out_planes, kernel_size, stride):
        super(SepConv, self).__init__()
        half_kernel = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
            groups=in_planes,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=out_planes)

    def forward(self, x):
        '''Return ``x`` passed through the convolution and then batch norm.'''
        convolved = self.conv1(x)
        return self.bn1(convolved)
class CellA(nn.Module):
    '''Cell that sums a 7x7 SepConv branch with a 3x3 max-pool branch.

    With ``stride == 2`` the pooled branch is additionally sent through a 1x1
    convolution (``in_planes`` -> ``out_planes``) and a batch norm before the
    two branches are added. The sum is passed through ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride shared by the SepConv and the max-pool (default 1).
    '''

    def __init__(self, in_planes, out_planes, stride=1):
        super(CellA, self).__init__()
        self.stride = stride
        self.sep_conv1 = SepConv(in_planes, out_planes, 7, stride)
        if stride == 2:
            # 1x1 projection used only on the pooled branch of a stride-2 cell.
            self.conv1 = nn.Conv2d(
                in_planes, out_planes,
                kernel_size=1, stride=1, padding=0, bias=False,
            )
            self.bn1 = nn.BatchNorm2d(out_planes)

    def _pool_branch(self, x):
        '''Max-pool ``x``; for stride-2 cells also apply conv1 and bn1.'''
        pooled = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
        if self.stride != 2:
            return pooled
        return self.bn1(self.conv1(pooled))

    def forward(self, x):
        '''Return ReLU(SepConv branch + pooled branch).'''
        conv_branch = self.sep_conv1(x)
        pool_branch = self._pool_branch(x)
        return F.relu(conv_branch + pool_branch)
class CellB(nn.Module):
    '''Cell with two summed branch pairs that are concatenated and reduced.

    Left pair:  7x7 SepConv + 3x3 SepConv, then ReLU.
    Right pair: 3x3 max-pool + 5x5 SepConv, then ReLU. With ``stride == 2``
    the pooled tensor first goes through a 1x1 convolution and a batch norm.
    The two results are concatenated along dim 1 and brought back to
    ``out_planes`` channels by a 1x1 convolution, batch norm and ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride shared by every SepConv and the max-pool (default 1).
    '''

    def __init__(self, in_planes, out_planes, stride=1):
        super(CellB, self).__init__()
        self.stride = stride

        # Left pair.
        self.sep_conv1 = SepConv(in_planes, out_planes, 7, stride)
        self.sep_conv2 = SepConv(in_planes, out_planes, 3, stride)

        # Right pair.
        self.sep_conv3 = SepConv(in_planes, out_planes, 5, stride)
        if stride == 2:
            # 1x1 projection used only on the pooled tensor of a stride-2 cell.
            self.conv1 = nn.Conv2d(
                in_planes, out_planes,
                kernel_size=1, stride=1, padding=0, bias=False,
            )
            self.bn1 = nn.BatchNorm2d(out_planes)

        # 1x1 reduction applied after the two pairs are concatenated.
        self.conv2 = nn.Conv2d(
            2 * out_planes, out_planes,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_planes)

    def _left(self, x):
        '''ReLU of the summed 7x7 and 3x3 SepConv outputs.'''
        return F.relu(self.sep_conv1(x) + self.sep_conv2(x))

    def _right(self, x):
        '''ReLU of the (optionally projected) max-pool plus the 5x5 SepConv.'''
        pooled = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
        if self.stride == 2:
            pooled = self.bn1(self.conv1(pooled))
        return F.relu(pooled + self.sep_conv3(x))

    def forward(self, x):
        '''Concatenate both pairs along dim 1, then reduce channels.'''
        merged = torch.cat([self._left(x), self._right(x)], 1)
        reduced = self.bn2(self.conv2(merged))
        return F.relu(reduced)
class PNASNet(nn.Module):
    '''Stack of cells: stem conv, three stages of cells, pooling, classifier.

    Layout: a 3x3 stem convolution (3 -> ``num_planes`` channels) with batch
    norm and ReLU, then three stages of ``num_cells`` stride-1 cells at widths
    ``num_planes``, ``2 * num_planes`` and ``4 * num_planes``. Consecutive
    stages are separated by a single stride-2 cell that also widens the
    channels. The result is average-pooled with an 8x8 window, flattened and
    fed to a linear classifier.

    Args:
        cell_type: callable invoked as ``cell_type(in_planes, out_planes,
            stride=...)`` that returns an ``nn.Module`` (e.g. a cell class).
        num_cells: number of stride-1 cells in each of the three stages.
        num_planes: channel width of the stem and of the first stage.
        num_classes: size of the classifier output (default 10).

    Note:
        The classifier has ``4 * num_planes`` input features, so the 8x8
        average pool must leave a 1x1 map; the in-file smoke test feeds
        32x32 inputs.
    '''

    def __init__(self, cell_type, num_cells, num_planes, num_classes=10):
        super(PNASNet, self).__init__()
        # Running channel count, advanced by _make_layer / _downsample.
        self.in_planes = num_planes
        self.cell_type = cell_type

        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_planes)

        # Honour the caller's num_cells (it was previously ignored and every
        # stage was built with a hard-coded 6 cells).
        self.layer1 = self._make_layer(num_planes, num_cells=num_cells)
        self.layer2 = self._downsample(num_planes * 2)
        self.layer3 = self._make_layer(num_planes * 2, num_cells=num_cells)
        self.layer4 = self._downsample(num_planes * 4)
        self.layer5 = self._make_layer(num_planes * 4, num_cells=num_cells)

        self.linear = nn.Linear(num_planes * 4, num_classes)

    def _make_layer(self, planes, num_cells):
        '''Return an nn.Sequential of ``num_cells`` stride-1 cells.

        The first cell maps ``self.in_planes`` -> ``planes``; ``self.in_planes``
        is then updated so the remaining cells map ``planes`` -> ``planes``.
        '''
        layers = []
        for _ in range(num_cells):
            layers.append(self.cell_type(self.in_planes, planes, stride=1))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _downsample(self, planes):
        '''Return one stride-2 cell mapping ``self.in_planes`` -> ``planes``.'''
        layer = self.cell_type(self.in_planes, planes, stride=2)
        self.in_planes = planes
        return layer

    def forward(self, x):
        '''Run ``x`` through the stem, all five layers, pooling and classifier.'''
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            out = stage(out)
        out = F.avg_pool2d(out, 8)
        # Flatten everything but the batch dimension for the linear layer.
        out = self.linear(out.view(out.size(0), -1))
        return out
def PNASNetA():
    '''Return a PNASNet built from CellA with num_cells=6 and num_planes=44.'''
    model = PNASNet(cell_type=CellA, num_cells=6, num_planes=44)
    return model
def PNASNetB():
    '''Return a PNASNet built from CellB with num_cells=6 and num_planes=32.'''
    model = PNASNet(cell_type=CellB, num_cells=6, num_planes=32)
    return model
def test():
    '''Smoke test: push one random 1x3x32x32 tensor through PNASNetB and print the result.'''
    model = PNASNetB()
    sample = torch.randn(1, 3, 32, 32)
    prediction = model(sample)
    print(prediction)
# test()