-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
convnext.py
331 lines (280 loc) · 11.9 KB
/
convnext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from itertools import chain
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
build_norm_layer)
from mmcv.runner import BaseModule
from mmcv.runner.base_module import ModuleList, Sequential
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
@NORM_LAYERS.register_module('LN2d')
class LayerNorm2d(nn.LayerNorm):
    """Channel-wise LayerNorm for channels-first 2d feature maps.

    The input of shape (N, C, H, W) is permuted to (N, H, W, C), normalized
    over its last (channel) dimension with
    :func:`torch.nn.functional.layer_norm`, then permuted back to
    (N, C, H, W).

    Args:
        num_channels (int): The number of channels of the input tensor.
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`torch.nn.LayerNorm`, for example ``eps`` (a value added
            to the denominator for numerical stability, defaults to 1e-5)
            and ``elementwise_affine`` (when ``True`` the module owns
            learnable per-element affine parameters initialized to ones for
            weights and zeros for biases; defaults to True).
    """

    def __init__(self, num_channels: int, **kwargs) -> None:
        super().__init__(num_channels, **kwargs)
        # nn.LayerNorm keeps the normalized shape as a tuple; expose its
        # first entry (the channel count) under a dedicated attribute.
        normalized = self.normalized_shape
        self.num_channels = normalized[0]

    def forward(self, x):
        """Normalize a 4-d (N, C, H, W) tensor across its channel axis."""
        assert x.dim() == 4, (
            'LayerNorm2d only supports inputs with shape '
            f'(N, C, H, W), but got tensor with shape {x.shape}')
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape,
                              self.weight, self.bias, self.eps)
        # Restore the channels-first layout expected by the caller.
        return normed.permute(0, 3, 1, 2)
class ConvNeXtBlock(BaseModule):
    """A single ConvNeXt residual block.

    Data flow: 7x7 depthwise conv -> norm -> pointwise expansion
    -> activation -> pointwise projection -> optional per-channel layer
    scale -> drop path, with the result added to the block input.

    Args:
        in_channels (int): The number of input channels (also the number of
            output channels, since the block is residual).
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        mlp_ratio (float): The expansion ratio in both pointwise convolution;
            the hidden width is ``int(mlp_ratio * in_channels)``.
            Defaults to 4.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. More details can be found in the note.
            Defaults to True.
        drop_path_rate (float): Stochastic depth rate. A value of 0 uses an
            identity instead of DropPath. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale. A value
            <= 0 disables layer scale (``gamma`` is None).
            Defaults to 1e-6.

    Note:
        There are two equivalent implementations:

        1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
           all outputs are in (N, C, H, W).
        2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
           -> Linear; Permute back

        As default, we use the second to align with the official repository.
        And it may be slightly faster.
    """

    def __init__(self,
                 in_channels,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 mlp_ratio=4.,
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6):
        super().__init__()
        # Submodules are created and registered in a fixed order
        # (depthwise conv, norm, pw1, act, pw2, gamma, drop path) so that
        # seeded initialization and parameter ordering stay stable.
        self.depthwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels)

        self.linear_pw_conv = linear_pw_conv
        _, self.norm = build_norm_layer(norm_cfg, in_channels)

        hidden_channels = int(mlp_ratio * in_channels)
        # nn.Linear acts on the last (channel) axis after a channels-last
        # permute; a 1x1 Conv2d acts on channels-first input directly.
        layer_factory = (
            nn.Linear if linear_pw_conv else partial(nn.Conv2d,
                                                     kernel_size=1))

        self.pointwise_conv1 = layer_factory(in_channels, hidden_channels)
        self.act = build_activation_layer(act_cfg)
        self.pointwise_conv2 = layer_factory(hidden_channels, in_channels)

        if layer_scale_init_value > 0:
            scale = torch.ones((in_channels)) * layer_scale_init_value
            self.gamma = nn.Parameter(scale, requires_grad=True)
        else:
            self.gamma = None

        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` (N, C, H, W); the output has the same shape."""
        identity = x
        out = self.norm(self.depthwise_conv(x))

        if self.linear_pw_conv:
            out = out.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pointwise_conv2(self.act(self.pointwise_conv1(out)))
        if self.linear_pw_conv:
            out = out.permute(0, 3, 1, 2)  # back to (N, C, H, W)

        if self.gamma is not None:
            # Per-channel layer scale, broadcast over N, H and W.
            out = out.mul(self.gamma.view(1, -1, 1, 1))

        return identity + self.drop_path(out)
@BACKBONES.register_module()
class ConvNeXt(BaseBackbone):
    """ConvNeXt.

    A PyTorch implementation of : `A ConvNet for the 2020s
    <https://arxiv.org/pdf/2201.03545.pdf>`_

    Modified from the `official repo
    <https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
            should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Both sequences must have the same length, which sets the number
            of stages. Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers. It is used in the
            stem, in the downsample layers between stages, inside every
            block and for the per-stage output norms.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Per-block rates grow
            linearly from 0 at the first block to this value at the last
            block. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages. Negative
            values count back from the last stage. Every index must resolve
            to a valid stage. Defaults to -1, means the last stage.
        frozen_stages (int): Number of leading stages to freeze. Each frozen
            stage and the downsample layer in front of it are put in eval
            mode and their parameters get ``requires_grad=False``.
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
    """  # noqa: E501
    # Preset architectures: number of blocks and channel width per stage.
    arch_settings = {
        'tiny': {
            'depths': [3, 3, 9, 3],
            'channels': [96, 192, 384, 768]
        },
        'small': {
            'depths': [3, 3, 27, 3],
            'channels': [96, 192, 384, 768]
        },
        'base': {
            'depths': [3, 3, 27, 3],
            'channels': [128, 256, 512, 1024]
        },
        'large': {
            'depths': [3, 3, 27, 3],
            'channels': [192, 384, 768, 1536]
        },
        'xlarge': {
            'depths': [3, 3, 27, 3],
            'channels': [256, 512, 1024, 2048]
        },
    }

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 stem_patch_size=4,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should be both sequence with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        # Resolve negative indices against the actual number of stages
        # (custom archs need not have 4). A new list is built so the
        # caller's list is never mutated and tuple inputs work as well.
        resolved_indices = []
        for index in out_indices:
            resolved = self.num_stages + index if index < 0 else index
            assert 0 <= resolved < self.num_stages, \
                f'Invalid out_indices {index}'
            resolved_indices.append(resolved)
        self.out_indices = resolved_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        # stochastic depth decay rule: the drop-path rate grows linearly
        # from 0 at the first block to ``drop_path_rate`` at the last one.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        block_idx = 0

        # One downsample layer per stage; the first is the patchify stem.
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        # Feature resolution stages, each consisting of multiple residual
        # blocks.
        self.stages = nn.ModuleList()

        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                # Norm followed by a 2x2, stride-2 conv between stages. The
                # norm is built from ``norm_cfg`` like the stem and blocks,
                # so the configured eps (1e-6 by default) applies here too.
                # A bare LayerNorm2d would use nn.LayerNorm's default 1e-5.
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1])[1],
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                ConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value)
                for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

            if i in self.out_indices:
                # Each output stage gets its own final norm, ``norm{i}``.
                norm_layer = build_norm_layer(norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

        self._freeze_stages()

    def forward(self, x):
        """Extract features from the requested stages.

        Args:
            x (torch.Tensor): Input images; assumed (N, C, H, W) since the
                stem is a ``nn.Conv2d``.

        Returns:
            tuple[torch.Tensor]: One tensor per stage listed in
            ``out_indices``, in stage order. With
            ``gap_before_final_norm=True`` each map is globally averaged
            over H and W, normalized, then flattened to (N, C). Otherwise the
            normalized feature map is returned as is.
        """
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    outs.append(norm_layer(x))

        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` downsample layers and stages in
        eval mode and stop gradients to their parameters."""
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        """Set training mode, keeping frozen stages in eval mode.

        Returns:
            ConvNeXt: ``self``, as ``nn.Module.train`` does, so calls can be
            chained (e.g. ``model = model.train()``).
        """
        super(ConvNeXt, self).train(mode)
        self._freeze_stages()
        return self