-
Notifications
You must be signed in to change notification settings - Fork 93
/
segformer_head.py
89 lines (71 loc) · 2.79 KB
/
segformer_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Obtained from: https://github.com/NVlabs/SegFormer
# Modifications: Model construction with loop
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
# A copy of the license is available at resources/license_segformer
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class MLP(nn.Module):
    """Per-position linear projection of a feature map.

    All dimensions from index 2 onward are collapsed into a single axis,
    that axis is swapped with dimension 1, and the result is fed through
    one ``nn.Linear`` layer acting on the last dimension.

    Args:
        input_dim: Size of the incoming feature dimension (dimension 1 of
            the tensor given to ``forward``).
        embed_dim: Size of the projected feature dimension.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        """Return the projected tokens with the feature axis last."""
        # (N, C, d2, d3, ...) -> (N, C, L) with L = d2 * d3 * ...
        tokens = torch.flatten(x, start_dim=2)
        # (N, C, L) -> (N, L, C) so the linear layer acts on the C axis.
        tokens = torch.transpose(tokens, 1, 2).contiguous()
        return self.proj(tokens)
@HEADS.register_module()
class SegFormerHead(BaseDecodeHead):
    """Decode head of "SegFormer: Simple and Efficient Design for Semantic
    Segmentation with Transformers".

    Each selected input feature map is projected to a common embedding
    width by its own :class:`MLP`, resized to the spatial size of
    ``inputs[0]``, concatenated along the channel axis, fused by a single
    ``ConvModule`` and classified by a 1x1 convolution.

    Args:
        **kwargs: Forwarded unchanged to ``BaseDecodeHead`` (together with
            ``input_transform='multiple_select'``). Must contain
            ``decoder_params``, a mapping with the keys ``embed_dim``
            (common channel width after projection) and
            ``conv_kernel_size`` (kernel size of the fusion convolution).
            ``norm_cfg`` is optional; when it is absent the fusion
            convolution is built with ``norm_cfg=None``.

    Raises:
        KeyError: If ``decoder_params`` or one of its two required keys is
            missing.
    """

    def __init__(self, **kwargs):
        super(SegFormerHead, self).__init__(
            input_transform='multiple_select', **kwargs)

        decoder_params = kwargs['decoder_params']
        embedding_dim = decoder_params['embed_dim']
        conv_kernel_size = decoder_params['conv_kernel_size']

        # One projection per selected input. Keys are the stringified input
        # indices because nn.ModuleDict only accepts string keys. Building
        # the ModuleDict directly avoids temporarily storing a plain dict
        # on the module.
        self.linear_c = nn.ModuleDict({
            str(i): MLP(input_dim=in_channels, embed_dim=embedding_dim)
            for i, in_channels in zip(self.in_index, self.in_channels)
        })

        # Fuses the channel-wise concatenation of all projected inputs back
        # down to ``embedding_dim`` channels.
        self.linear_fuse = ConvModule(
            in_channels=embedding_dim * len(self.in_index),
            out_channels=embedding_dim,
            kernel_size=conv_kernel_size,
            padding=0 if conv_kernel_size == 1 else conv_kernel_size // 2,
            # ``norm_cfg`` is optional: fall back to None instead of raising
            # KeyError when the caller did not pass it.
            norm_cfg=kwargs.get('norm_cfg'))

        self.linear_pred = nn.Conv2d(
            embedding_dim, self.num_classes, kernel_size=1)

    def forward(self, inputs):
        """Compute per-class logits from a sequence of feature maps.

        Args:
            inputs: Indexable sequence of feature maps. Entries are read
                directly with the indices in ``self.in_index``; each one
                is indexed up to ``shape[3]``, so 4-D tensors are expected
                (presumably ``(N, C, H, W)`` -- confirm against the
                backbone).

        Returns:
            The output of ``self.linear_pred`` applied to the fused
            features, i.e. a tensor with ``self.num_classes`` channels.
        """
        # The batch size is taken from the last feature map, as before.
        batch_size = inputs[-1].shape[0]
        # Every projected map is brought to the resolution of inputs[0].
        target_size = inputs[0].shape[2:]

        projected = []
        for i in self.in_index:
            feat = inputs[i]
            # MLP yields (N, L, C'); move channels back to dimension 1 and
            # restore the spatial layout of the source feature map.
            tokens = self.linear_c[str(i)](feat)
            fmap = tokens.permute(0, 2, 1).contiguous().reshape(
                batch_size, -1, feat.shape[2], feat.shape[3])
            # Index 0 is the resize target itself, so it is left untouched.
            if i != 0:
                fmap = resize(
                    fmap,
                    size=target_size,
                    mode='bilinear',
                    align_corners=False)
            projected.append(fmap)

        fused = self.linear_fuse(torch.cat(projected, dim=1))

        # NOTE(review): ``self.dropout`` is not set in this class; it is
        # presumably provided by BaseDecodeHead -- confirm.
        if self.dropout is not None:
            fused = self.dropout(fused)

        return self.linear_pred(fused)