-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
jpu.py
131 lines (117 loc) · 4.96 KB
/
jpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import NECKS
@NECKS.register_module()
class JPU(BaseModule):
    """FastFCN: Rethinking Dilated Convolution in the Backbone
    for Semantic Segmentation.

    This Joint Pyramid Upsampling (JPU) neck is the implementation of
    `FastFCN <https://arxiv.org/abs/1903.11816>`_.

    Each selected backbone level is first projected to ``mid_channels``
    with a 3x3 conv, all projected maps are resized to the spatial size of
    the first selected level and concatenated, and the result is fed in
    parallel to one depthwise separable conv per dilation rate. The outputs
    of these parallel branches are concatenated along the channel dim.

    Args:
        in_channels (Tuple[int], optional): The number of input channels
            for each convolution operations before upsampling.
            Default: (512, 1024, 2048).
        mid_channels (int): The number of output channels of every
            internal branch (each per-level 3x3 conv and each dilated
            depthwise separable conv). The fused feature returned by
            :meth:`forward` concatenates all dilation branches and
            therefore has ``len(dilations) * mid_channels`` channels.
            Default: 512.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        dilations (tuple[int]): Dilation rate of each Depthwise
            Separable ConvModule. Default: (1, 2, 4, 8).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
        # An empty level range would build zero-channel dilation branches
        # and leave ``forward`` without any feature map to resize to, so
        # reject it here with a readable message.
        assert self.start_level < self.backbone_end_level, (
            'start_level must be smaller than the (exclusive) end level, '
            f'but got start_level={self.start_level} and '
            f'end level={self.backbone_end_level}.')

        self.dilations = dilations
        self.align_corners = align_corners

        # One 3x3 projection per selected backbone level. The single-item
        # nn.Sequential wrapper is kept on purpose so that submodule names
        # (and therefore state_dict keys) stay unchanged.
        self.conv_layers = nn.ModuleList()
        for level in range(self.start_level, self.backbone_end_level):
            self.conv_layers.append(
                nn.Sequential(
                    ConvModule(
                        self.in_channels[level],
                        self.mid_channels,
                        kernel_size=3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg)))

        # Every dilation branch consumes the channel-wise concatenation of
        # all projected levels.
        num_levels = self.backbone_end_level - self.start_level
        fused_channels = num_levels * self.mid_channels
        self.dilation_layers = nn.ModuleList()
        for dilation in dilations:
            # padding == dilation with a 3x3 kernel and stride 1, as in the
            # original configuration of each branch.
            self.dilation_layers.append(
                nn.Sequential(
                    DepthwiseSeparableConvModule(
                        in_channels=fused_channels,
                        out_channels=self.mid_channels,
                        kernel_size=3,
                        stride=1,
                        padding=dilation,
                        dilation=dilation,
                        dw_norm_cfg=norm_cfg,
                        dw_act_cfg=None,
                        pw_norm_cfg=norm_cfg,
                        pw_act_cfg=act_cfg)))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): Backbone feature maps, one per entry
                of ``in_channels``.

        Returns:
            tuple[Tensor]: The untouched inputs of levels
            ``start_level`` .. ``end level - 2`` followed by the fused JPU
            feature map as the last element.
        """
        # Implicit string concatenation keeps the message free of the
        # indentation whitespace a backslash continuation would embed.
        assert len(inputs) == len(self.in_channels), (
            'Length of inputs must '
            'be the same with self.in_channels!')

        feats = [
            self.conv_layers[level - self.start_level](inputs[level])
            for level in range(self.start_level, self.backbone_end_level)
        ]

        # Bring every projected level to the spatial size (last two dims)
        # of the first selected level before concatenating.
        target_size = tuple(feats[0].shape[2:])
        for idx in range(1, len(feats)):
            feats[idx] = resize(
                feats[idx],
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
        feat = torch.cat(feats, dim=1)

        # Run the parallel dilated branches and stack them channel-wise.
        concat_feat = torch.cat(
            [branch(feat) for branch in self.dilation_layers], dim=1)

        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
        # the feature map from backbone for auxiliary head. Additionally,
        # outs[0] can also be used for auxiliary head.
        outs = [
            inputs[level]
            for level in range(self.start_level, self.backbone_end_level - 1)
        ]
        outs.append(concat_feat)
        return tuple(outs)