-
Notifications
You must be signed in to change notification settings - Fork 509
/
pnasnet.py
executable file
·130 lines (102 loc) · 4.24 KB
/
pnasnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""PNASNet in PyTorch.
Paper: Progressive Neural Architecture Search
https://github.com/kuangliu/pytorch-cifar/blob/master/models/pnasnet.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from super_gradients.training.models import BaseClassifier
class SepConv(nn.Module):
    """Depthwise convolution followed by BatchNorm.

    Despite the name there is no pointwise (1x1) stage: ``conv1`` is a single
    grouped convolution with ``groups=in_planes``, so ``out_planes`` must be a
    multiple of ``in_planes`` (a ``nn.Conv2d`` requirement). Padding is
    ``(kernel_size - 1) // 2``, which keeps the spatial size for odd kernels
    at stride 1.

    Args:
        in_planes: Number of input channels (also the number of groups).
        out_planes: Number of output channels.
        kernel_size: Square kernel size.
        stride: Convolution stride.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride):
        super().__init__()
        pad = (kernel_size - 1) // 2
        # Registration order (conv1, then bn1) is kept so parameter
        # initialisation order and state_dict keys match checkpoints.
        self.conv1 = nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            groups=in_planes,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=out_planes)

    def forward(self, x):
        features = self.conv1(x)
        return self.bn1(features)
class CellA(nn.Module):
    """PNASNet "A" cell: a 7x7 separable-conv branch summed with a max-pool branch.

    The two branches are added element-wise and passed through ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride used by both branches.

    The max-pool branch keeps ``in_planes`` channels, so it is projected with a
    1x1 conv + BatchNorm (``conv1``/``bn1``) to ``out_planes`` channels when
    ``stride == 2`` (as before) or when the channel counts differ (previously
    that case crashed on the element-wise add).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super(CellA, self).__init__()
        self.stride = stride
        self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
        # Every configuration that worked before builds exactly the same modules
        # (and therefore the same state_dict keys).
        self.use_projection = stride == 2 or in_planes != out_planes
        if self.use_projection:
            self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
            self.bn1 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        y1 = self.sep_conv1(x)
        # kernel 3 / padding 1 gives the same output size as the 7x7 / padding 3
        # separable conv for any stride, so the two branches can be summed.
        y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
        if self.use_projection:
            y2 = self.bn1(self.conv1(y2))
        return F.relu(y1 + y2)
class CellB(nn.Module):
    """PNASNet "B" cell: two summed branch pairs, concatenated and reduced.

    Left pair: 7x7 and 3x3 separable convs, summed then ReLU.
    Right pair: 3x3 max pool (optionally projected) and a 5x5 separable conv,
    summed then ReLU. The two results are concatenated along the channel axis
    (``2 * out_planes`` channels) and reduced back to ``out_planes`` by a
    1x1 conv + BatchNorm + ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride used by every branch.

    The max-pool branch keeps ``in_planes`` channels, so it is projected with a
    1x1 conv + BatchNorm (``conv1``/``bn1``) when ``stride == 2`` (as before)
    or when the channel counts differ (previously that case crashed on the add).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super(CellB, self).__init__()
        self.stride = stride
        # Left branch
        self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
        self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
        # Right branch
        self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
        # Same module set as before for every configuration that used to work.
        self.use_projection = stride == 2 or in_planes != out_planes
        if self.use_projection:
            self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
            self.bn1 = nn.BatchNorm2d(out_planes)
        # Reduce channels
        self.conv2 = nn.Conv2d(2 * out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        # Left branch
        y1 = self.sep_conv1(x)
        y2 = self.sep_conv2(x)
        # Right branch
        y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
        if self.use_projection:
            y3 = self.bn1(self.conv1(y3))
        y4 = self.sep_conv3(x)
        # Concat & reduce channels
        b1 = F.relu(y1 + y2)
        b2 = F.relu(y3 + y4)
        y = torch.cat([b1, b2], 1)
        return F.relu(self.bn2(self.conv2(y)))
class PNASNet(BaseClassifier):
    """PNASNet feature extractor plus linear classifier.

    Layout:
        1. 3x3 stride-1 stem conv.
        2. ``num_cells`` cells at ``num_planes`` channels.
        3. A stride-2 cell to ``2 * num_planes`` channels.
        4. ``num_cells`` cells at that width.
        5. A stride-2 cell to ``4 * num_planes`` channels.
        6. ``num_cells`` cells at that width.
        7. Global average pool.
        8. Linear layer.

    Args:
        cell_type: Cell class (e.g. ``CellA`` / ``CellB``), called as
            ``cell_type(in_planes, out_planes, stride=...)``.
        num_cells: Number of stride-1 cells in each of the three stages.
        num_planes: Channel width of the first stage.
        num_classes: Number of output logits (default 10).

    Expects 3-channel input (``conv1`` has 3 input channels). The global pool
    adapts to any spatial size. For 32x32 inputs (final 8x8 map) it equals
    the previous fixed ``avg_pool2d(out, 8)``.
    """

    def __init__(self, cell_type, num_cells, num_planes, num_classes=10):
        super(PNASNet, self).__init__()
        self.in_planes = num_planes
        self.cell_type = cell_type

        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_planes)

        # Use the requested ``num_cells``; the stages previously hard-coded 6
        # and silently ignored the argument.
        self.layer1 = self._make_layer(num_planes, num_cells=num_cells)
        self.layer2 = self._downsample(num_planes * 2)
        self.layer3 = self._make_layer(num_planes * 2, num_cells=num_cells)
        self.layer4 = self._downsample(num_planes * 4)
        self.layer5 = self._make_layer(num_planes * 4, num_cells=num_cells)
        self.linear = nn.Linear(num_planes * 4, num_classes)

    def _make_layer(self, planes, num_cells):
        """Stack ``num_cells`` stride-1 cells that output ``planes`` channels."""
        layers = []
        for _ in range(num_cells):
            layers.append(self.cell_type(self.in_planes, planes, stride=1))
            # Later cells consume the width produced by the previous one.
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _downsample(self, planes):
        """Build one stride-2 cell that outputs ``planes`` channels."""
        layer = self.cell_type(self.in_planes, planes, stride=2)
        self.in_planes = planes
        return layer

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # A fixed 8x8 window only fits 32x32 inputs. On larger final maps it
        # either averaged just the top-left 8x8 region or left a spatial extent
        # that does not match ``self.linear``. Adaptive pooling always reduces
        # the map to 1x1.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
def PNASNetA():
    """Build a PNASNet made of ``CellA`` cells: 6 cells per stage, 44 base channels."""
    model = PNASNet(cell_type=CellA, num_planes=44, num_cells=6)
    return model
def PNASNetB():
    """Build a PNASNet made of ``CellB`` cells: 6 cells per stage, 32 base channels."""
    model = PNASNet(cell_type=CellB, num_planes=32, num_cells=6)
    return model
def test():
    """Smoke test: push one random 1x3x32x32 tensor through PNASNetB and print the logits."""
    model = PNASNetB()
    sample = torch.randn(1, 3, 32, 32)
    logits = model(sample)
    print(logits)
# test()