-
Notifications
You must be signed in to change notification settings - Fork 509
/
densenet.py
executable file
·181 lines (147 loc) · 7.88 KB
/
densenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from collections import OrderedDict
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.training.models import BaseClassifier
"""Densenet-BC model class, based on
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Code source: https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
Performance reproducibility (4 GPUs):
training params: {"max_epochs": 120, "lr_updates": [30, 60, 90, 100, 110], "lr_decay_factor": 0.1, "initial_lr": 0.025}
dataset_params: {"batch_size": 64}
"""
class _DenseLayer(nn.Module):
    """One dense layer: BN -> ReLU -> 1x1 conv (bottleneck) -> BN -> ReLU -> 3x3 conv.

    The 1x1 convolution produces ``bn_size * growth_rate`` channels and the
    3x3 convolution reduces them to ``growth_rate`` channels. Dropout is
    applied to the result only when ``drop_rate`` is positive.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        """
        :param num_input_features: channel count of the (concatenated) input
        :param growth_rate: channel count produced by this layer
        :param bn_size: multiplier giving the bottleneck width (bn_size * growth_rate)
        :param drop_rate: dropout probability; stored as a float
        """
        super().__init__()
        bottleneck_width = bn_size * growth_rate

        # Assigning a module to an attribute registers it under that name,
        # in assignment order, so names and ordering stay norm1..conv2.
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bottleneck_width, kernel_size=1, stride=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bottleneck_width, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.drop_rate = float(drop_rate)

    def bn_function(self, inputs):
        """Concatenate ``inputs`` along dim 1 and run the bottleneck (norm1 -> relu1 -> conv1)."""
        stacked = torch.cat(inputs, 1)
        normalized = self.norm1(stacked)
        activated = self.relu1(normalized)
        return self.conv1(activated)

    def forward(self, input):  # noqa: F811
        """Compute the new feature maps from a single tensor or a list of tensors."""
        # A lone tensor is wrapped so that bn_function always receives a sequence.
        if isinstance(input, Tensor):
            sources = [input]
        else:
            sources = input

        squeezed = self.bn_function(sources)
        produced = self.conv2(self.relu2(self.norm2(squeezed)))

        if self.drop_rate > 0:
            produced = F.dropout(produced, p=self.drop_rate, training=self.training)
        return produced
class _DenseBlock(nn.ModuleDict):
    """Stack of dense layers in which every layer sees all preceding feature maps.

    Layers are stored under the keys ``denselayer1`` .. ``denselayerN``. The
    i-th layer (0-based) is built for ``num_input_features + i * growth_rate``
    input channels.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        """
        :param num_layers: number of dense layers to create
        :param num_input_features: channel count of the block input
        :param bn_size: bottleneck multiplier forwarded to every layer
        :param growth_rate: channels added by every layer
        :param drop_rate: dropout probability forwarded to every layer
        """
        super().__init__()
        for position in range(1, num_layers + 1):
            # Each layer receives the block input plus the output of all earlier layers.
            in_width = num_input_features + (position - 1) * growth_rate
            dense_layer = _DenseLayer(
                in_width,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
            )
            self.add_module(f"denselayer{position}", dense_layer)

    def forward(self, init_features):
        """Run all layers and return the input and every layer output concatenated along dim 1."""
        collected = [init_features]
        for dense_layer in self.values():
            # The layer is given the full running list; its output then joins that list.
            collected.append(dense_layer(collected))
        return torch.cat(collected, 1)
class _Transition(nn.Sequential):
    """Sequential stage placed between dense blocks.

    Applies, in order: batch norm, in-place ReLU, a bias-free 1x1 convolution
    mapping ``num_input_features`` to ``num_output_features`` channels, and a
    2x2 average pool with stride 2.
    """

    def __init__(self, num_input_features, num_output_features):
        """
        :param num_input_features: channel count entering the transition
        :param num_output_features: channel count produced by the 1x1 convolution
        """
        super().__init__()
        stages = (
            ("norm", nn.BatchNorm2d(num_input_features)),
            ("relu", nn.ReLU(inplace=True)),
            ("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)),
            ("pool", nn.AvgPool2d(kernel_size=2, stride=2)),
        )
        # Registration order defines the execution order of the Sequential.
        for label, stage in stages:
            self.add_module(label, stage)
class DenseNet(BaseClassifier):
    """DenseNet classifier: convolutional stem, dense blocks separated by transitions, linear head."""

    def __init__(self, growth_rate: int, structure: list, num_init_features: int, bn_size: int, drop_rate: float, num_classes: int, in_channels: int = 3):
        """
        :param growth_rate:         number of filters each dense layer adds (noted as 'k' in the paper)
        :param structure:           number of dense layers in each dense block, in order
        :param num_init_features:   number of filters learned by the first convolutional layer
        :param bn_size:             multiplicative factor for the bottleneck width
                                    (i.e. bn_size * k features in the bottleneck)
        :param drop_rate:           dropout rate applied after each dense layer
        :param num_classes:         number of classes in the classification task
        :param in_channels:         number of channels in the input image
        """
        super().__init__()

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU and 3x3 stride-2 max pool.
        stem = OrderedDict()
        stem["conv0"] = nn.Conv2d(in_channels, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
        stem["norm0"] = nn.BatchNorm2d(num_init_features)
        stem["relu0"] = nn.ReLU(inplace=True)
        stem["pool0"] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.features = nn.Sequential(stem)

        # Dense blocks; every block except the last is followed by a transition that halves the channel count.
        channels = num_init_features
        for stage, depth in enumerate(structure):
            dense_block = _DenseBlock(num_layers=depth, num_input_features=channels, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module(f"denseblock{stage + 1}", dense_block)
            channels += depth * growth_rate

            is_final_stage = stage == len(structure) - 1
            if is_final_stage:
                continue

            halved = channels // 2
            transition = _Transition(num_input_features=channels, num_output_features=halved)
            self.features.add_module(f"transition{stage + 1}", transition)
            channels = halved

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(channels))

        # Linear layer
        self.classifier = nn.Linear(channels, num_classes)

        # Official init from torch repo.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                # Only the bias is reset; the linear weight keeps its construction-time values.
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run the feature extractor, apply ReLU and global average pooling, then classify."""
        feature_map = F.relu(self.features(x), inplace=True)
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
        flattened = torch.flatten(pooled, 1)
        return self.classifier(flattened)

    def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
        """Replace the first module of ``self.features`` (the stem convolution) with the result of
        ``replace_conv2d_input_channels``.

        :param in_channels:             new number of input channels
        :param compute_new_weights_fn:  optional callable forwarded as ``fn`` to ``replace_conv2d_input_channels``
        """
        from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

        stem_conv = self.features[0]
        self.features[0] = replace_conv2d_input_channels(conv=stem_conv, in_channels=in_channels, fn=compute_new_weights_fn)

    def get_input_channels(self) -> int:
        """Return ``in_channels`` of the first module in ``self.features``."""
        stem_conv = self.features[0]
        return stem_conv.in_channels
@register_model(Models.CUSTOM_DENSENET)
class CustomizedDensnet(DenseNet):
    """DenseNet whose hyper-parameters are read from ``arch_params``, with defaults for missing ones."""

    def __init__(self, arch_params):
        """
        :param arch_params: object providing ``num_classes`` (required) and optionally
                            ``growth_rate`` (default 32), ``structure`` (default [6, 12, 24, 16]),
                            ``num_init_features`` (default 64), ``bn_size`` (default 4)
                            and ``drop_rate`` (default 0)
        """
        # getattr with a default falls back exactly when the attribute is absent.
        super().__init__(
            growth_rate=getattr(arch_params, "growth_rate", 32),
            structure=getattr(arch_params, "structure", [6, 12, 24, 16]),
            num_init_features=getattr(arch_params, "num_init_features", 64),
            bn_size=getattr(arch_params, "bn_size", 4),
            drop_rate=getattr(arch_params, "drop_rate", 0),
            num_classes=arch_params.num_classes,
        )
@register_model(Models.DENSENET121)
class DenseNet121(DenseNet):
    """DenseNet with growth rate 32, block structure [6, 12, 24, 16] and 64 initial features."""

    def __init__(self, arch_params):
        """
        :param arch_params: object providing ``num_classes``
        """
        super().__init__(
            growth_rate=32,
            structure=[6, 12, 24, 16],
            num_init_features=64,
            bn_size=4,
            drop_rate=0,
            num_classes=arch_params.num_classes,
        )
@register_model(Models.DENSENET161)
class DenseNet161(DenseNet):
    """DenseNet with growth rate 48, block structure [6, 12, 36, 24] and 96 initial features."""

    def __init__(self, arch_params):
        """
        :param arch_params: object providing ``num_classes``
        """
        super().__init__(
            growth_rate=48,
            structure=[6, 12, 36, 24],
            num_init_features=96,
            bn_size=4,
            drop_rate=0,
            num_classes=arch_params.num_classes,
        )
@register_model(Models.DENSENET169)
class DenseNet169(DenseNet):
    """DenseNet with growth rate 32, block structure [6, 12, 32, 32] and 64 initial features."""

    def __init__(self, arch_params):
        """
        :param arch_params: object providing ``num_classes``
        """
        super().__init__(
            growth_rate=32,
            structure=[6, 12, 32, 32],
            num_init_features=64,
            bn_size=4,
            drop_rate=0,
            num_classes=arch_params.num_classes,
        )
@register_model(Models.DENSENET201)
class DenseNet201(DenseNet):
    """DenseNet with growth rate 32, block structure [6, 12, 48, 32] and 64 initial features."""

    def __init__(self, arch_params):
        """
        :param arch_params: object providing ``num_classes``
        """
        super().__init__(
            growth_rate=32,
            structure=[6, 12, 48, 32],
            num_init_features=64,
            bn_size=4,
            drop_rate=0,
            num_classes=arch_params.num_classes,
        )