[Feature] Add ESPNetV1 (#1625)
simuler authored Dec 28, 2021
1 parent aad2c1c commit bfca53f
Showing 4 changed files with 351 additions and 0 deletions.
14 changes: 14 additions & 0 deletions configs/espnetv1/README.md
@@ -0,0 +1,14 @@
# ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation


## Reference

> Sachin Mehta, Mohammad Rastegari, Anat Caspi, Linda Shapiro, and Hannaneh Hajishirzi. "ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation." In Proceedings of the European Conference on Computer Vision, pp. 552-568. 2018.

## Performance

### Cityscapes

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|ESPNetV1|-|1024x512|120000|61.82%|62.20%|62.89%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/espnetv1_cityscapes_1024x512_120k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/espnetv1_cityscapes_1024x512_120k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=472e91a0600420c99a0dc3a1e6f80f87)
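
For a quick sanity check, the model can be built directly in Python with the same hyperparameters as the config below (a minimal sketch; it assumes a working PaddleSeg installation and feeds a random tensor instead of real Cityscapes data):

```python
import paddle
from paddleseg.models import ESPNetV1

# Hyperparameters mirror espnetv1_cityscapes_1024x512_120k.yml.
model = ESPNetV1(num_classes=19, in_channels=3, level2_depth=2, level3_depth=8)
model.eval()

x = paddle.randn([1, 3, 512, 1024])  # NCHW dummy input at 1024x512
with paddle.no_grad():
    logits = model(x)[0]  # the model returns a list with one logit tensor
print(logits.shape)  # [1, 19, 512, 1024]
```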
28 changes: 28 additions & 0 deletions configs/espnetv1/espnetv1_cityscapes_1024x512_120k.yml
@@ -0,0 +1,28 @@
_base_: '../_base_/cityscapes.yml'

batch_size: 4
iters: 120000

optimizer:
_inherited_: False
type: adam
weight_decay: 0.0002

lr_scheduler:
type: PolynomialDecay
learning_rate: 0.001
end_lr: 0.0
power: 0.9

loss:
types:
- type: CrossEntropyLoss
      weight: [2.79834108, 6.92945723, 3.84068512, 9.94349362, 9.77098823, 9.51484, 10.30981624, 9.94307377, 4.64933892, 9.55759938, 7.86692178, 9.53126629, 10.3496365, 6.67234062, 10.26054204, 10.28785275, 10.28988296, 10.40546021, 10.13848367]
coef: [1]

model:
type: ESPNetV1
in_channels: 3
num_classes: 19
level2_depth: 2
level3_depth: 8
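
The `PolynomialDecay` schedule above anneals the learning rate from `learning_rate` down to `end_lr` over `iters` steps. A small standalone sketch of the resulting curve (the standard polynomial-decay formula, written out here for illustration rather than taken from PaddleSeg internals):

```python
def poly_lr(step, base_lr=0.001, end_lr=0.0, power=0.9, total_iters=120000):
    """Polynomial decay: base_lr falls to end_lr over total_iters steps."""
    frac = 1.0 - min(step, total_iters) / total_iters
    return (base_lr - end_lr) * frac ** power + end_lr

for step in (0, 30000, 60000, 90000, 120000):
    print(step, round(poly_lr(step), 6))
# 0 0.001 / 30000 ~0.000772 / 60000 ~0.000536 / 90000 ~0.000287 / 120000 0.0
```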
1 change: 1 addition & 0 deletions paddleseg/models/__init__.py
@@ -49,3 +49,4 @@
from .hrnet_contrast import HRNetW48Contrast
from .espnet import ESPNetV2
from .dmnet import DMNet
from .espnetv1 import ESPNetV1
308 changes: 308 additions & 0 deletions paddleseg/models/espnetv1.py
@@ -0,0 +1,308 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.models import layers
from paddleseg.cvlibs import manager
from paddleseg.utils import utils


@manager.MODELS.add_component
class ESPNetV1(nn.Layer):
"""
The ESPNetV1 implementation based on PaddlePaddle.
The original article refers to
    Sachin Mehta, Mohammad Rastegari, Anat Caspi, Linda Shapiro, and Hannaneh Hajishirzi. "ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation"
(https://arxiv.org/abs/1803.06815).
Args:
num_classes (int): The unique number of target classes.
in_channels (int, optional): Number of input channels. Default: 3.
        level2_depth (int, optional): Number of DilatedResidualBlock units in encoder level 2. Default: 2.
        level3_depth (int, optional): Number of DilatedResidualBlock units in encoder level 3. Default: 3.
pretrained (str, optional): The path or url of pretrained model. Default: None.
"""
def __init__(self,
num_classes,
in_channels=3,
level2_depth=2,
level3_depth=3,
pretrained=None):
super().__init__()
self.encoder = ESPNetEncoder(num_classes, in_channels, level2_depth,
level3_depth)

self.level3_up = nn.Conv2DTranspose(num_classes,
num_classes,
2,
stride=2,
padding=0,
output_padding=0,
bias_attr=False)
self.br3 = layers.SyncBatchNorm(num_classes)
self.level2_proj = nn.Conv2D(in_channels + 128,
num_classes,
1,
bias_attr=False)
self.combine_l2_l3 = nn.Sequential(
BNPReLU(2 * num_classes),
DilatedResidualBlock(2 * num_classes, num_classes, residual=False),
)
self.level2_up = nn.Sequential(
nn.Conv2DTranspose(num_classes,
num_classes,
2,
stride=2,
padding=0,
output_padding=0,
bias_attr=False),
BNPReLU(num_classes),
)
self.out_proj = layers.ConvBNPReLU(16 + in_channels + num_classes,
num_classes,
3,
padding='same',
stride=1)
self.out_up = nn.Conv2DTranspose(num_classes,
num_classes,
2,
stride=2,
padding=0,
output_padding=0,
bias_attr=False)
        self.pretrained = pretrained
        self.init_weight()

def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)

def forward(self, x):
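        # Encoder returns class-projected features at 1/2 (p1), 1/4 (p2) and
        # 1/8 (p3) of the input resolution; the decoder below upsamples p3,
        # fuses it with p2, upsamples again, fuses with p1, and a final x2
        # transposed conv restores the full input size.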
p1, p2, p3 = self.encoder(x)
up_p3 = self.level3_up(p3)

combine = self.combine_l2_l3(paddle.concat([up_p3, p2], axis=1))
up_p2 = self.level2_up(combine)

combine = self.out_proj(paddle.concat([up_p2, p1], axis=1))
out = self.out_up(combine)
return [out]


class BNPReLU(nn.Layer):
def __init__(self, channels):
super().__init__()
self.bn = layers.SyncBatchNorm(channels)
self.act = nn.PReLU(channels)

def forward(self, x):
x = self.bn(x)
x = self.act(x)
return x


class DownSampler(nn.Layer):
"""
    Strided ESP module: a stride-2 3x3 conv halves the spatial resolution
    before the spatial pyramid of dilated convolutions is applied.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
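        # out_channels is split evenly across the five parallel branches; the
        # dilation-1 branch absorbs the remainder when out_channels % 5 != 0.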
branch_channels = out_channels // 5
remain_channels = out_channels - branch_channels * 4
self.conv1 = nn.Conv2D(in_channels,
branch_channels,
3,
stride=2,
padding=1,
bias_attr=False)
self.d_conv1 = nn.Conv2D(branch_channels,
remain_channels,
3,
padding=1,
bias_attr=False)
self.d_conv2 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=2,
dilation=2,
bias_attr=False)
self.d_conv4 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=4,
dilation=4,
bias_attr=False)
self.d_conv8 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=8,
dilation=8,
bias_attr=False)
self.d_conv16 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=16,
dilation=16,
bias_attr=False)
self.bn = layers.SyncBatchNorm(out_channels)
self.act = nn.PReLU(out_channels)

def forward(self, x):
x = self.conv1(x)
d1 = self.d_conv1(x)
d2 = self.d_conv2(x)
d4 = self.d_conv4(x)
d8 = self.d_conv8(x)
d16 = self.d_conv16(x)

        # Hierarchical feature fusion (HFF): sum the dilated outputs
        # progressively so gridding artifacts cancel before concatenation.
        feat1 = d2
        feat2 = feat1 + d4
        feat3 = feat2 + d8
        feat4 = feat3 + d16

feat = paddle.concat([d1, feat1, feat2, feat3, feat4], axis=1)
out = self.bn(feat)
out = self.act(out)
return out


class DilatedResidualBlock(nn.Layer):
'''
ESP block, principle: reduce -> split -> transform -> merge
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
residual (bool, optional): Add a residual connection through identity operation. Default: True.
'''
def __init__(self, in_channels, out_channels, residual=True):
super().__init__()
branch_channels = out_channels // 5
remain_channels = out_channels - branch_channels * 4
self.conv1 = nn.Conv2D(in_channels, branch_channels, 1, bias_attr=False)
self.d_conv1 = nn.Conv2D(branch_channels,
remain_channels,
3,
padding=1,
bias_attr=False)
self.d_conv2 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=2,
dilation=2,
bias_attr=False)
self.d_conv4 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=4,
dilation=4,
bias_attr=False)
self.d_conv8 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=8,
dilation=8,
bias_attr=False)
self.d_conv16 = nn.Conv2D(branch_channels,
branch_channels,
3,
padding=16,
dilation=16,
bias_attr=False)

self.bn = BNPReLU(out_channels)
self.residual = residual

def forward(self, x):
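        # Reduce: a 1x1 conv shrinks the channel count, then the projected
        # feature feeds five parallel 3x3 convs with dilation 1, 2, 4, 8, 16.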
x_proj = self.conv1(x)
d1 = self.d_conv1(x_proj)
d2 = self.d_conv2(x_proj)
d4 = self.d_conv4(x_proj)
d8 = self.d_conv8(x_proj)
d16 = self.d_conv16(x_proj)

        # HFF merge: progressive sums, as in DownSampler.
        feat1 = d2
        feat2 = feat1 + d4
        feat3 = feat2 + d8
        feat4 = feat3 + d16

feat = paddle.concat([d1, feat1, feat2, feat3, feat4], axis=1)

if self.residual:
feat = feat + x
out = self.bn(feat)
return out


class ESPNetEncoder(nn.Layer):
'''
The ESPNet-C implementation based on PaddlePaddle.
Args:
num_classes (int): The unique number of target classes.
in_channels (int, optional): Number of input channels. Default: 3.
level2_depth (int, optional): Depth of DilatedResidualBlock. Default: 5.
level3_depth (int, optional): Depth of DilatedResidualBlock. Default: 3.
'''
def __init__(self,
num_classes,
in_channels=3,
level2_depth=5,
level3_depth=3):
super().__init__()
self.level1 = layers.ConvBNPReLU(in_channels,
16,
3,
padding='same',
stride=2)
self.br1 = BNPReLU(in_channels + 16)
self.proj1 = layers.ConvBNPReLU(in_channels + 16, num_classes, 1)

self.level2_0 = DownSampler(in_channels + 16, 64)
self.level2 = nn.Sequential(
*[DilatedResidualBlock(64, 64) for i in range(level2_depth)])
self.br2 = BNPReLU(in_channels + 128)
self.proj2 = layers.ConvBNPReLU(in_channels + 128, num_classes, 1)

self.level3_0 = DownSampler(in_channels + 128, 128)
self.level3 = nn.Sequential(
*[DilatedResidualBlock(128, 128) for i in range(level3_depth)])
self.br3 = BNPReLU(256)
self.proj3 = layers.ConvBNPReLU(256, num_classes, 1)

def forward(self, x):
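        # At levels 1 and 2 an average-pooled copy of the raw input is
        # concatenated with the learned features, which is why the layer
        # widths above read `in_channels + 16` and `in_channels + 128`.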
f1 = self.level1(x)
down2 = F.adaptive_avg_pool2d(x, output_size=f1.shape[2:])
feat1 = paddle.concat([f1, down2], axis=1)
feat1 = self.br1(feat1)
p1 = self.proj1(feat1)

f2_res = self.level2_0(feat1)
f2 = self.level2(f2_res)
down4 = F.adaptive_avg_pool2d(x, output_size=f2.shape[2:])
feat2 = paddle.concat([f2, f2_res, down4], axis=1)
feat2 = self.br2(feat2)
p2 = self.proj2(feat2)

f3_res = self.level3_0(feat2)
f3 = self.level3(f3_res)
feat3 = paddle.concat([f3, f3_res], axis=1)
feat3 = self.br3(feat3)
p3 = self.proj3(feat3)

return p1, p2, p3
